<a href="https://colab.research.google.com/github/jenlee04/olm/blob/main/olm_3500final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Mount Google Drive, where the SmellNet data is stored.





In [53]:
# Mount Google Drive so the dataset folders under '/content/drive/My Drive/' are readable.
# Colab-only: google.colab is not available outside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Import the packages used throughout the notebook.


In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import os
from pathlib import Path
import copy

In [55]:
# load in datasets
off_tes_path = '/content/drive/My Drive/3500final/data/offline_testing/'
off_tra_path = '/content/drive/My Drive/3500final/data/offline_training'
on_nut_path = '/content/drive/My Drive/3500final/data/online_nuts/'
on_spice_path = '/content/drive/My Drive/3500final/data/online_spices/'


def load_odor_csvs(base_path):
  """
  Purpose: recursively read every non-empty .csv file under base_path.

  The odor name of a file is the name of the folder that directly contains it.
  Folders and files are visited in sorted order, so the sample order is the
  same on every run and machine (os.walk itself gives no ordering guarantee).

  params
  base_path: root folder to search

  returns
  data: list of pandas DataFrames, one per readable csv file
  names: list of odor names (str), aligned with data
  """
  if not os.path.isdir(base_path):
    # os.walk silently yields nothing for a missing folder; make that visible
    print(f"warning: data folder not found: {base_path}")
    return [], []

  data = []
  names = []
  for root, dirs, files in os.walk(base_path):
    # sorting dirs in place fixes the order os.walk descends in
    dirs.sort()
    for filename in sorted(files):
      if not filename.endswith(".csv"):
        continue
      csv_path = os.path.join(root, filename)
      if os.path.getsize(csv_path) == 0:
        print(f"skipped empty file: {csv_path}")
        continue
      try:
        frame = pd.read_csv(csv_path, header = 0)
      except Exception as e:
        # best effort: report an unreadable file and keep loading the rest
        print(f"error reading {csv_path}: {e}")
        continue
      data.append(frame)
      names.append(Path(csv_path).parent.name)
  return data, names


offline_testing_data, offline_testing_labels = load_odor_csvs(off_tes_path)
offline_training_data, offline_training_labels = load_odor_csvs(off_tra_path)
online_nuts_data, online_nuts_labels = load_odor_csvs(on_nut_path)
online_spices_data, online_spices_labels = load_odor_csvs(on_spice_path)

# every odor name seen in any split gets an integer id (ids follow sorted name order)
all_odors_in_order = set(offline_testing_labels).union(
    offline_training_labels, online_nuts_labels, online_spices_labels)
odor_to_label = {name: idx for idx, name in enumerate(sorted(all_odors_in_order))}
label_to_odor = {idx: name for name, idx in odor_to_label.items()}

# integer vals for labels
offline_testing_labels = [odor_to_label[name] for name in offline_testing_labels]
offline_training_labels = [odor_to_label[name] for name in offline_training_labels]
online_nuts_labels = [odor_to_label[name] for name in online_nuts_labels]
online_spices_labels = [odor_to_label[name] for name in online_spices_labels]

Olfactory receptors.

In [56]:
class OlfactoryReceptorLayer(nn.Module):
  """
  Purpose: stand-in for olfactory receptor neurons (ORNs). A fixed random
  projection maps every sensor vector onto many receptors, so each odor
  excites its own combination of receptors (combinatorial coding).

  params
  input_dim: number of sensor features per timestep (default 12)
  n_receptors: number of artificial olfactory receptors (default 100)

  attributes
  self.layer: bias-free nn.Linear(input_dim, n_receptors). Its weight is
  re-drawn from a standard normal N(0, 1) and frozen (requires_grad False), so
  this layer never learns; it only encodes the input.

  methods
  forward(gsdata): applies the frozen projection to the last dimension of the
  gas sensor data and rectifies it with ReLU

  returns
  receptor activations with n_receptors in the last dimension
  (e.g. [batch, time, n_receptors])
  """
  def __init__(self, input_dim = 12, n_receptors = 100):
    super().__init__()
    self.input_dim, self.n_receptors = input_dim, n_receptors

    projection = nn.Linear(input_dim, n_receptors, bias = False)
    # replace nn.Linear's default init with N(0, 1), then freeze it
    nn.init.normal_(projection.weight, mean = 0.0, std = 1.0)
    projection.weight.requires_grad_(False)
    self.layer = projection

  def forward(self, gsdata):
    # gsdata: [batch, time, input_dim] -> [batch, time, n_receptors]
    pre_activation = self.layer(gsdata)
    return torch.relu(pre_activation)




Next, the receptor activations are passed to an olfactory bulb layer that can learn from them.

In [57]:
class OlfactoryBulb(nn.Module):
  """
  Purpose: learn from the specific activations of the receptors and correlate
  them with specific odorants/features of the gas sensor input that is fed into
  the receptors. Periglomerular cells (PGs) are ignored here, so there is no
  fine tuning; they could be added in a later version.

  params
  n_receptors: "input dimension" -> receptor activations used as input here
  n_mitral: "output dimension" -> number of mitral cells to simulate
  lateral_strength: scale of the lateral inhibition between the mitral cells
  (intended range 0 to 1; not enforced)

  attributes
  self.alpha: divisive-normalization gain (output = m / (1 + alpha * rms(m)))
  self.eps: small constant inside the square root, avoids sqrt(0)
  self.gru: Gated Recurrent Unit (a simplified LSTM) that processes the time
  series of receptor signals
  self.lateral_w: frozen buffer [n_mitral, n_mitral] with
  w[i, j] = exp(-5 * |i - j| / n_mitral) for i != j and 0 on the diagonal

  methods
  forward: takes the receptor activity ([b, t, n_receptors]) and returns the
  mitral cell activations ([b, t, n_mitral]) after lateral inhibition and
  normalization are applied

  returns
  mitral cell activations [b, t, n_mitral]
  """
  def __init__(self, n_receptors, n_mitral, lateral_strength = 0.3):
    super().__init__()
    self.alpha = 0.1
    self.eps = 1e-6
    self.lateral_strength = lateral_strength

    # temporal processing - processes the receptor inputs over time
    self.gru = nn.GRU(n_receptors, n_mitral, batch_first = True)

    # lateral inhibition matrix: frozen (a buffer, not a parameter) and
    # symmetric. Nearer mitral cells (by index) inhibit each other more
    # strongly; the strength decays exponentially with distance but never
    # reaches zero. Built in one vectorized step instead of an O(n^2) Python
    # loop over scalar tensors. The distance is computed in float64 (as the
    # Python-float loop did) and cast to the default dtype before exp.
    idx = torch.arange(n_mitral, dtype = torch.float64)
    dis = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs() / n_mitral
    lateral_w = torch.exp((-dis * 5).to(torch.get_default_dtype()))
    lateral_w.fill_diagonal_(0.0)  # a cell does not inhibit itself
    self.register_buffer('lateral_w', lateral_w)


  def forward(self, r):
    # r: receptor activity [B, T, n_receptors] (GRU is batch_first)
    m, _ = self.gru(r)  # [B, T, n_mitral]

    # lateral inhibition: each cell is reduced by the rectified activity of the
    # other cells, weighted by lateral_strength * lateral_w
    B, T, M = m.shape
    m_flat = m.reshape(B * T, M)
    inhib = F.linear(torch.relu(m_flat), self.lateral_strength * self.lateral_w)
    m_inhib = (m_flat - inhib).reshape(B, T, M)

    # divisive normalization by the RMS activity across mitral cells
    norm = torch.sqrt((m_inhib ** 2).mean(dim = -1, keepdim = True) + self.eps)
    return m_inhib / (1 + self.alpha * norm)

Next, the piriform cortex interprets the bulb output; it learns through Hebbian plasticity in addition to backpropagation.

In [58]:
class Piriform(nn.Module):
  """
  Purpose: sparse associations; Hebbian learning occurs here. Simulates the
  piriform cortex: odor patterns should be associated with odor identities.

  params
  n_mitral: number of mitral cell inputs from the OB
  n_piric: number of piriform cortex neurons
  thresh: activation threshold used for sparse coding

  attributes
  self.w: trainable weights [n_piric, n_mitral]. They are updated by the
  optimizer (it is an nn.Parameter) and, optionally, by hebbian_update
  self.thresh: sparsity threshold (a unit is active only if its drive > thresh)

  methods
  forward: takes mitral cell activations [batch, time, n_mitral], averages
  them over time, applies the linear map, subtracts thresh, then ReLU.
  Returns piriform activations [batch, n_piric]

  hebbian_update: adds lr * (batch mean of the outer product of piriform
  activity and time-averaged mitral activity) to the weights
  """
  def __init__(self, n_mitral, n_piric, thresh = 0.1):
    super().__init__()

    # establish the weights (small random init)
    self.w = nn.Parameter(torch.randn(n_piric, n_mitral) * 0.1)
    self.thresh = thresh


  def forward(self, m):
    # m: mitral cell activity [batch, time, n_mitral]
    mean_m = m.mean(dim = 1)
    activ = torch.relu(F.linear(mean_m, self.w) - self.thresh)
    return activ

  # no gradient tracking needed: this is a local learning rule, not backprop
  @torch.no_grad()
  def hebbian_update(self, m, activ, lr = 1e-4):
    # Hebbian rule: the weight change is proportional to post-synaptic
    # (piriform) activity times pre-synaptic (mitral) activity.
    # activ: [B, n_piric], mean_m: [B, n_mitral]; the batch-averaged outer
    # product equals activ.T @ mean_m / B and has shape [n_piric, n_mitral]
    mean_m = m.mean(dim = 1)
    delw = (activ.unsqueeze(-1) * mean_m.unsqueeze(1)).mean(dim = 0)

    # in-place update of the parameter itself (under no_grad) rather than via
    # .data, so autograd's version tracking still sees the change.
    # NOTE(review): plain Hebbian growth is unbounded (activ >= 0); long runs
    # may need weight decay or normalization (e.g. Oja's rule).
    self.w.add_(lr * delw)


Full model

In [59]:
class OLM(nn.Module):
  """
  Purpose: everything put together! Goes from the receptors to the bulb to the
  piriform cortex to an output head that classifies the odor.

  params
  input_dim: number of input features (12 sensor columns in this notebook)
  n_receptors: number of olfactory receptors
  n_mitral: number of mitral cells
  n_piric: number of piriform cortex cells
  n_class: number of odor classes for classification (50 in this notebook)
  lateral_strength: lateral inhibition strength for the olfactory bulb
  piric_thresh: activation threshold of the piriform layer (default 0.1,
  matching Piriform's own default)

  attributes
  self.receptors: the olfactory receptors
  self.bulb: olfactory bulb
  self.piriform: piriform cortex
  self.classifier: single trainable linear layer, trained with supervision,
  that classifies the odors

  methods
  forward: takes sensor input [batch, time, input_dim] and returns the odor
  logits plus the intermediate bulb and piriform activations

  returns
  logits - predictions for the odor identity [batch, n_class]
  bulb - mitral cell activations [batch, time, n_mitral]
  pc - piriform cortex activations [batch, n_piric]
  """
  def __init__(self, input_dim, n_receptors, n_mitral, n_piric, n_class, lateral_strength = 0.3,
               piric_thresh = 0.1):
    super().__init__()
    self.receptors = OlfactoryReceptorLayer(input_dim, n_receptors)
    self.bulb = OlfactoryBulb(n_receptors, n_mitral, lateral_strength)
    self.piriform = Piriform(n_mitral, n_piric, thresh = piric_thresh)
    self.classifier = nn.Linear(n_piric, n_class)

  def forward(self, inp):
    rec = self.receptors(inp)      # [batch, time, n_receptors]
    bulb = self.bulb(rec)          # [batch, time, n_mitral]
    pc = self.piriform(bulb)       # [batch, n_piric] (time-averaged inside)
    logits = self.classifier(pc)   # [batch, n_class]

    return logits, bulb, pc



Next, a Dataset class that wraps the dataframes loaded earlier (normalization, optional FOTD, padding/truncation).

In [60]:
class SmellNetData(torch.utils.data.Dataset):
  """
  Purpose: dataset over gas sensor recordings. It selects the 12 sensor columns,
  normalizes them, optionally applies FOTD, and pads/truncates every sample to
  a common length.

  params
  dataframes: list of gas sensor DataFrames (one per csv file)
  labels: integer labels, one per dataframe (same length as dataframes)
  mean: per-sensor mean used for normalization (e.g. the training-set
  statistics from compute_sens_stat)
  std: per-sensor std used for normalization
  use_fotd: True/False, toggles FOTD (applied after normalization)
  p: FOTD lag (must be >= 1 when FOTD is used; p = 25 here)
  max_len: fixed sequence length. None (or 0) means the longest recording in
  `dataframes`. Pass the training set's max_len to the other splits so that
  every split is padded to the same length.

  attributes
  self.sensor_cols: names of the 12 features of the input data
  self.max_len: fixed sequence length

  methods
  fotd: lag-p difference of a [time, features] tensor (same shape out)
  __len__(): returns number of samples in the dataset
  __getitem__(idx): retrieve and preprocess a single sample at index idx,
  returns (x [max_len, 12] float32, y long scalar)

  raises
  ValueError: if dataframes and labels differ in length, or if FOTD is
  requested with p < 1
  """
  def __init__(self, dataframes, labels, mean, std, use_fotd = False, p = 25, max_len = None):
    if len(dataframes) != len(labels):
      raise ValueError(
          f"got {len(dataframes)} dataframes but {len(labels)} labels")
    self.dataframes = dataframes
    self.labels = labels
    self.use_fotd = use_fotd
    self.p = p
    self.mean = mean
    self.std = std

    # default=0 keeps an empty dataset constructible instead of raising in max()
    self.max_len = max_len or max((len(df) for df in dataframes), default = 0)

    self.sensor_cols = ["NO2", "C2H5OH", "VOC", "CO", "Alcohol",
                        "LPG", "Benzene", "Temperature", "Pressure",
                        "Humidity", "Gas_Resistance", "Altitude"]

  def fotd(self, x):
    """
    Lag-p difference: out[t] = x[t] - x[t - p] for t >= p. The first p steps
    are zero, and the whole output is zero if the recording has <= p steps.
    The output has the same shape, dtype and device as x.
    """
    if self.p < 1:
      raise ValueError(f"FOTD lag p must be >= 1, got {self.p}")
    out = x.new_zeros(x.shape)
    if x.shape[0] > self.p:
      out[self.p:] = x[self.p:] - x[:-self.p]
    return out

  def __len__(self):
    return len(self.dataframes)

  def __getitem__(self, idx):
    df = self.dataframes[idx]
    x = torch.tensor(df[self.sensor_cols].values, dtype = torch.float32)

    # dataset-level normalization (per-sensor statistics passed in), since
    # recordings have varying row counts
    x = (x - self.mean) / (self.std + 1e-6)
    if self.use_fotd:
      x = self.fotd(x)

    # pad with zeros / truncate the time dimension to max_len
    T, D = x.shape
    if T < self.max_len:
      pad = torch.zeros(self.max_len - T, D)
      x = torch.cat([x, pad], dim = 0)
    else:
      x = x[:self.max_len]

    y = torch.tensor(self.labels[idx], dtype = torch.long)
    return x, y


Training and evaluation helper functions.

In [61]:
def training(model, loader, optimizer, criterion, device, use_hebbian = False, hebb_lr = 1e-3):
  """
  Purpose: run one epoch of training, with optional Hebbian updates.

  params
  model: OLM to train (must return (logits, bulb, pc); needs .piriform when
  use_hebbian is True)
  loader: iterable of (x, y) training batches
  optimizer: pytorch optimizer, in this case Adam
  criterion: loss function, in this case cross entropy
  device: cuda or cpu
  use_hebbian: if True, a Hebbian update is applied to the piriform weights
  after each gradient step
  hebb_lr: learning rate for Hebbian updates

  returns
  avg_loss: mean of the per-batch losses
  acc: training accuracy in percent (argmax of the logits)
  Both are NaN when the loader yields no batches.
  """
  model.train()
  total_loss = 0.0
  correct = 0
  total = 0
  n_batches = 0

  for x, y in loader:
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()
    logits, bulb, pc = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()

    if use_hebbian:
      # uses the activations from this batch's forward pass (before the step)
      with torch.no_grad():
        model.piriform.hebbian_update(bulb, pc, lr = hebb_lr)

    total_loss += loss.item()
    correct += (logits.argmax(dim = 1) == y).sum().item()
    total += y.size(0)
    n_batches += 1

  if n_batches == 0:
    # nothing to average; avoid ZeroDivisionError
    return float('nan'), float('nan')

  avg_loss = total_loss / n_batches
  acc = 100. * correct / total
  return avg_loss, acc

@torch.no_grad()
def evaluate(model, loader, device, k = 5):
  """
  Purpose: evaluate model performance on test data with top-1 and top-k accuracy.

  params
  model: olm in this case (must return (logits, _, _))
  loader: iterable of (x, y) test batches
  device: cuda or cpu
  k: number of top predictions to consider; clamped to the number of classes

  returns
  top1_acc: top-1 accuracy in percent (is the model's first choice correct?)
  topk_acc: top-k accuracy in percent (is the correct label among the top k?)
  Both are NaN when the loader yields no samples.
  """
  model.eval()
  correct1 = 0
  correctk = 0
  total = 0

  for x, y in loader:
    x, y = x.to(device), y.to(device)
    logits, _, _ = model(x)

    # topk raises if k exceeds the number of classes, so clamp it
    _, topk = logits.topk(min(k, logits.size(1)), dim = 1)
    correct1 += (topk[:, 0] == y).sum().item()
    correctk += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()
    total += y.size(0)

  if total == 0:
    return float('nan'), float('nan')

  top1_acc = 100. * correct1 / total
  topk_acc = 100. * correctk / total
  return top1_acc, topk_acc


@torch.no_grad()
def evaluate_by_category(model, dataframes, labels, device, ingredient_to_category, label_to_odor,
                         mean, std, use_fotd = False, p = 25, max_len = None,
                         k = 5):
  """
  Purpose: analyze model performance per food category, using the categories
  given by Feng et al., 2025.

  params
  model: olm in this case (must return (logits, _, _))
  dataframes: raw gas sensor data, one DataFrame per sample
  labels: integer labels for each SAMPLE
  device: cuda or cpu
  ingredient_to_category: dictionary, maps the odor name to its category;
  odors missing from it are counted under 'Unknown'
  label_to_odor: dictionary, maps the integer label to the odor name
  mean, std: per-sensor normalization statistics (pass the TRAINING statistics)
  use_fotd, p: apply the lag-p FOTD after normalization (p must be >= 1)
  max_len: pad/truncate length. Pass train_dataset.max_len so samples are
  padded exactly like the training data; if None, the longest recording in
  `dataframes` is used
  k: top-k parameter (clamped to the number of classes); only k = 5 is used here

  returns
  category_results: dictionary, maps each category that has samples to
  {'top1_acc', 'top5_acc', 'n_samples'} ('top5_acc' holds the top-k accuracy)

  raises
  ValueError: if use_fotd is True and p < 1
  """
  model.eval()
  sensor_cols = ["NO2", "C2H5OH", "VOC", "CO", "Alcohol",
                 "LPG", "Benzene", "Temperature", "Pressure",
                 "Humidity", "Gas_Resistance", "Altitude"]
  if max_len is None:
    max_len = max((len(df) for df in dataframes), default = 0)
  if use_fotd and p < 1:
    raise ValueError(f"FOTD lag p must be >= 1, got {p}")

  def fotd(x, p):
    # lag-p difference; the first p steps (all steps if T <= p) stay zero
    out = x.new_zeros(x.shape)
    if x.shape[0] > p:
      out[p:] = x[p:] - x[:-p]
    return out

  category_stats = {}
  for df, label in zip(dataframes, labels):
    odor_name = label_to_odor[label]
    category = ingredient_to_category.get(odor_name, 'Unknown')
    # setdefault also creates the 'Unknown' bucket, which would otherwise
    # raise KeyError for odors missing from ingredient_to_category
    stats = category_stats.setdefault(category, {'correct1': 0, 'correct5': 0, 'total': 0})

    # preprocess exactly like SmellNetData, using the statistics passed in
    x = torch.tensor(df[sensor_cols].values, dtype = torch.float32)
    x = (x - mean) / (std + 1e-6)
    if use_fotd:
      x = fotd(x, p)

    T, D = x.shape
    if T < max_len:
      x = torch.cat([x, torch.zeros(max_len - T, D)], dim = 0)
    else:
      x = x[:max_len]

    # predictions for a batch of one
    logits, _, _ = model(x.unsqueeze(0).to(device))
    _, topk = logits.topk(min(k, logits.size(1)), dim = 1)

    # updates
    y_tensor = torch.tensor([label]).to(device)
    stats['correct1'] += (topk[:, 0] == y_tensor).sum().item()
    stats['correct5'] += (topk == y_tensor.unsqueeze(1)).any(dim = 1).sum().item()
    stats['total'] += 1

  category_results = {}
  for category, stats in category_stats.items():
    if stats['total'] > 0:
      category_results[category] = {
          'top1_acc': 100. * stats['correct1'] / stats['total'],
          'top5_acc': 100. * stats['correct5'] / stats['total'],
          'n_samples' : stats['total']
      }

  return category_results

Data setup: normalization statistics, datasets, dataloaders, and device selection (the model itself is created in the next cell).

In [62]:
def compute_sens_stat(dataframes, sensor_cols):
  """
  Purpose: per-sensor normalization statistics, pooled over every row of
  every recording in the dataset.

  params
  dataframes: list of gas sensor DataFrames (one per csv file)
  sensor_cols: names of the feature columns to summarize

  returns
  mean: per-column mean, float32 tensor [len(sensor_cols)]
  std: per-column standard deviation (torch.std default correction), same shape
  """
  # stack all recordings along the time axis: [total_rows, len(sensor_cols)]
  stacked = torch.cat(
      [torch.tensor(frame[sensor_cols].values, dtype = torch.float32) for frame in dataframes],
      dim = 0,
  )
  return stacked.mean(dim = 0), stacked.std(dim = 0)

sensor_cols = ["NO2", "C2H5OH", "VOC", "CO", "Alcohol",
               "LPG", "Benzene", "Temperature", "Pressure",
               "Humidity", "Gas_Resistance", "Altitude"]
# taken directly from smellnet's code
ingredient_to_category = {
    # Nuts
    "peanuts": "Nuts",
    "cashew": "Nuts",
    "chestnuts": "Nuts",
    "pistachios": "Nuts",
    "almond": "Nuts",
    "hazelnut": "Nuts",
    "walnuts": "Nuts",
    "pecans": "Nuts",
    "brazil_nut": "Nuts",
    "pili_nut": "Nuts",

    # Spices
    "cumin": "Spices",
    "star_anise": "Spices",
    "nutmeg": "Spices",
    "cloves": "Spices",
    "ginger": "Spices",
    "allspice": "Spices",
    "chervil": "Spices",
    "mustard": "Spices",
    "cinnamon": "Spices",
    "saffron": "Spices",

    # Herbs
    "angelica": "Herbs",
    "garlic": "Herbs",
    "chives": "Herbs",
    "turnip": "Herbs",
    "dill": "Herbs",
    "mugwort": "Herbs",
    "chamomile": "Herbs",
    "coriander": "Herbs",
    "oregano": "Herbs",
    "mint": "Herbs",

    # Fruits
    "kiwi": "Fruits",
    "pineapple": "Fruits",
    "banana": "Fruits",
    "lemon": "Fruits",
    "mandarin_orange": "Fruits",
    "strawberry": "Fruits",
    "apple": "Fruits",
    "mango": "Fruits",
    "peach": "Fruits",
    "pear": "Fruits",

    # Vegetables
    "cauliflower": "Vegetables",
    "brussel_sprouts": "Vegetables",
    "broccoli": "Vegetables",
    "sweet_potato": "Vegetables",
    "asparagus": "Vegetables",
    "avocado": "Vegetables",
    "radish": "Vegetables",
    "tomato": "Vegetables",
    "potato": "Vegetables",
    "cabbage": "Vegetables",
}
# normalization statistics come from the training split only and are reused
# for every other split
train_mean, train_std = compute_sens_stat(offline_training_data, sensor_cols)

# datasets
# training data
train_dataset = SmellNetData(offline_training_data, offline_training_labels,
                             mean = train_mean, std = train_std, use_fotd = False)

# Every other split is padded/truncated to the TRAINING max_len. Otherwise each
# split pads to its own longest recording. The piriform layer averages over all
# timesteps, padding included, so a different padded length would shift the
# features the classifier sees at evaluation time.
train_max_len = train_dataset.max_len

# testing data
test_dataset = SmellNetData(offline_testing_data, offline_testing_labels,
                            mean = train_mean, std = train_std, use_fotd = False,
                            max_len = train_max_len)

# fotd on
train_datasetf = SmellNetData(offline_training_data, offline_training_labels,
                              mean = train_mean, std = train_std, use_fotd = True, p = 25,
                              max_len = train_max_len)
test_datasetf = SmellNetData(offline_testing_data, offline_testing_labels,
                             mean = train_mean, std = train_std, use_fotd = True, p = 25,
                             max_len = train_max_len)
# online data

#absolute
online_nut_dataseta = SmellNetData(online_nuts_data, online_nuts_labels, mean = train_mean,
                                   std = train_std, use_fotd = False, max_len = train_max_len)
online_spices_dataseta = SmellNetData(online_spices_data, online_spices_labels,
                                      mean = train_mean, std = train_std, use_fotd = False,
                                      max_len = train_max_len)

#fotd on
online_nut_datasetf = SmellNetData(online_nuts_data, online_nuts_labels, mean = train_mean,
                                   std = train_std, use_fotd = True, p = 25,
                                   max_len = train_max_len)
online_spices_datasetf = SmellNetData(online_spices_data, online_spices_labels,
                                      mean = train_mean, std = train_std, use_fotd = True, p = 25,
                                      max_len = train_max_len)

# dataloaders
batch = 16
#offline (only the training loaders shuffle)
train_loader = DataLoader(train_dataset, batch_size = batch, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch, shuffle = False)

train_loaderf = DataLoader(train_datasetf, batch_size = batch, shuffle = True)
test_loaderf = DataLoader(test_datasetf, batch_size = batch, shuffle = False)

#online
online_nut_loadera = DataLoader(online_nut_dataseta, batch_size = batch, shuffle = False)
online_spices_loadera = DataLoader(online_spices_dataseta, batch_size = batch, shuffle = False)
online_nut_loaderf = DataLoader(online_nut_datasetf, batch_size = batch, shuffle = False)
online_spices_loaderf = DataLoader(online_spices_datasetf, batch_size = batch, shuffle = False)

# sanity check
n_classes = len(odor_to_label)
print(f"\nNumber of Classes: {n_classes}")



# model initialization
device = "cuda" if torch.cuda.is_available() else "cpu"



Number of Classes: 50


Training loop

In [None]:
lr = 1e-4
n_receptors = 500
n_mitral = 200
n_piric = 400
n_class = 50  # not used in this cell: the model is sized with n_classes (from the data)


# model: receptors -> bulb -> piriform -> classifier
olm1 = OLM(input_dim = 12, n_receptors = n_receptors, n_mitral = n_mitral,
           n_piric = n_piric, n_class = n_classes, lateral_strength = 0.3).to(device)
optimizer1 = torch.optim.Adam(olm1.parameters(), lr = lr)
criterion = nn.CrossEntropyLoss()

num_epochs = 200
train_losses, train_accs, test_top1accs, test_top5accs = [], [], [], []
history1 = (train_losses, train_accs, test_top1accs, test_top5accs)

results1 = {}

# TEST 1: absolute readings, Hebbian updates enabled
for epoch in range(num_epochs):
  train_loss, train_acc = training(olm1, train_loader, optimizer1, criterion, device,
                                   use_hebbian = True, hebb_lr = 1e-4)
  test_top1, test_top5 = evaluate(olm1, test_loader, device, k = 5)

  # storage: one value per epoch in each history list
  for series, value in zip(history1, (train_loss, train_acc, test_top1, test_top5)):
    series.append(value)

  epoch_no = epoch + 1
  if epoch == 0 or epoch_no % 10 == 0:
    print(f"Epoch {epoch_no}/{num_epochs}")
    print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f" Test top1: {test_top1:.2f}%, Test top5: {test_top5:.2f}%")


results1.update(train_losses = train_losses, train_accs = train_accs,
                test_top1accs = test_top1accs, test_top5accs = test_top5accs)




Epoch 1/200
 Train Loss: 8.5215, Train Acc: 2.00%
 Test top1: 2.00%, Test top5: 10.00%
Epoch 10/200
 Train Loss: 4.2841, Train Acc: 4.80%
 Test top1: 6.00%, Test top5: 24.00%
Epoch 20/200
 Train Loss: 3.6645, Train Acc: 15.20%
 Test top1: 24.00%, Test top5: 40.00%
Epoch 30/200
 Train Loss: 2.9782, Train Acc: 24.80%
 Test top1: 22.00%, Test top5: 68.00%
Epoch 40/200
 Train Loss: 2.5874, Train Acc: 35.60%
 Test top1: 42.00%, Test top5: 76.00%
Epoch 50/200
 Train Loss: 2.1916, Train Acc: 44.00%
 Test top1: 32.00%, Test top5: 80.00%
Epoch 60/200
 Train Loss: 1.8644, Train Acc: 41.20%
 Test top1: 44.00%, Test top5: 86.00%
Epoch 70/200
 Train Loss: 1.6832, Train Acc: 46.80%
 Test top1: 56.00%, Test top5: 80.00%
Epoch 80/200
 Train Loss: 1.5480, Train Acc: 52.80%
 Test top1: 52.00%, Test top5: 84.00%
Epoch 90/200
 Train Loss: 1.3254, Train Acc: 59.20%
 Test top1: 62.00%, Test top5: 86.00%
Epoch 100/200
 Train Loss: 1.1595, Train Acc: 61.60%
 Test top1: 60.00%, Test top5: 92.00%
Epoch 110/200


In [None]:
#TEST 2: now with fotd pre-processing (same architecture and hyperparameters as TEST 1)
olm2 = OLM(input_dim = 12, n_receptors = n_receptors, n_mitral = n_mitral,
           n_piric = n_piric, n_class = n_classes, lateral_strength = 0.3).to(device)
optimizer2 = torch.optim.Adam(olm2.parameters(), lr = lr)

training_loss, training_acc, testing_top1, testing_top5 = [], [], [], []
history2 = (training_loss, training_acc, testing_top1, testing_top5)

results2 = {}
for epoch in range(num_epochs):
  train_loss, train_acc = training(olm2, train_loaderf, optimizer2, criterion, device,
                                   use_hebbian = True, hebb_lr = 1e-4)
  test_top1, test_top5 = evaluate(olm2, test_loaderf, device, k = 5)

  # storage: one value per epoch in each history list
  for series, value in zip(history2, (train_loss, train_acc, test_top1, test_top5)):
    series.append(value)

  epoch_no = epoch + 1
  if epoch == 0 or epoch_no % 10 == 0:
    print(f"Epoch {epoch_no}/{num_epochs}")
    print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f" Test top1: {test_top1:.2f}%, Test top5: {test_top5:.2f}%")


results2.update(train_losses = training_loss, train_accs = training_acc,
                test_top1accs = testing_top1, test_top5accs = testing_top5)


In [None]:
# TEST 3: evaluation of online learning using nuts and spices datasets
# absolute, nuts
olm3 = copy.deepcopy(olm1)
olm3.bulb.gru.flatten_parameters()
og_classifier_weight = olm3.classifier.weight.data.clone()
og_classifier_bias = olm3.classifier.bias.data.clone() if olm3.classifier.bias is not None else None
og_piriform_w = olm3.piriform.w.data.clone()

#before and after
correct1_before_nuts = 0
correct5_before_nuts = 0

correct1_after_nuts = 0
correct5_after_nuts = 0
total_nuts = 0

# optimizer
adapt_lr = 1e-5
n_adapt_steps =5
adapt_optimizer = torch.optim.Adam(olm3.classifier.parameters(), lr = adapt_lr)

for x, y in online_nut_loadera:
  x, y = x.to(device), y.to(device)

  # before adap.
  olm3.eval()
  with torch.no_grad():
    logits, bulb, pc = olm3(x)
    _, topk = logits.topk(5, dim = 1)
    correct1_before_nuts += (topk[:,0] == y).sum().item()
    correct5_before_nuts += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  # online adaptation
  olm3.train()
  for step in range(n_adapt_steps):
    adapt_optimizer.zero_grad()
    logits, bulb, pc = olm3(x)
    loss = criterion(logits, y)
    loss.backward()
    adapt_optimizer.step()

    # hebb update during adap.
    with torch.no_grad():
      olm3.piriform.hebbian_update(bulb, pc, lr = 1e-5)

  # after adap.
  olm3.eval()
  with torch.no_grad():
    logits, _ , _ = olm3(x)
    _, topk = logits.topk(5, dim = 1)
    correct1_after_nuts += (topk[:,0] == y).sum().item()
    correct5_after_nuts += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  total_nuts += y.size(0)
#accuracy calcs
top1_before_nuts = 100. * correct1_before_nuts / total_nuts
top5_before_nuts = 100. * correct5_before_nuts / total_nuts
top1_after_nuts = 100. * correct1_after_nuts / total_nuts
top5_after_nuts = 100. * correct5_after_nuts / total_nuts

improvement1na = top1_after_nuts - top1_before_nuts
improvement5na = top5_after_nuts - top5_before_nuts

# reset for use for spices
olm3.classifier.weight.data = og_classifier_weight.clone()
if og_classifier_bias is not None:
    olm3.classifier.bias.data = og_classifier_bias.clone()
olm3.piriform.w.data = og_piriform_w.clone()


# fotd, nuts
olm4 = copy.deepcopy(olm2)
olm4.bulb.gru.flatten_parameters()
og_classifier_weightf = olm4.classifier.weight.data.clone()
og_classifier_biasf = olm4.classifier.bias.data.clone() if olm4.classifier.bias is not None else None
og_piriform_wf = olm4.piriform.w.data.clone()

correct1_before_nutsf = 0
correct5_before_nutsf = 0

correct1_after_nutsf = 0
correct5_after_nutsf = 0
total_nutsf = 0

# optimizer
adapt_lrf = 1e-5
n_adapt_stepsf =5
adapt_optimizerf = torch.optim.Adam(olm4.classifier.parameters(), lr = adapt_lrf)

for x, y in online_nut_loaderf:
  x, y = x.to(device), y.to(device)

  # before adap.
  olm4.eval()
  with torch.no_grad():
    logits, bulb, pc = olm4(x)
    _, topk = logits.topk(5, dim = 1)
    correct1_before_nutsf += (topk[:,0] == y).sum().item()
    correct5_before_nutsf += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  # online adaptation
  olm4.train()
  for step in range(n_adapt_stepsf):
    adapt_optimizerf.zero_grad()
    logits, bulb, pc = olm4(x)
    loss = criterion(logits, y)
    loss.backward()
    adapt_optimizerf.step()

    # hebb update during adap.
    with torch.no_grad():
      olm4.piriform.hebbian_update(bulb, pc, lr = 1e-5)
  # after adap.
  olm4.eval()
  with torch.no_grad():
    logits, _ , _ = olm4(x)
    _, topk = logits.topk(5, dim = 1)
    correct1_after_nutsf += (topk[:,0] == y).sum().item()
    correct5_after_nutsf += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  total_nutsf += y.size(0)
#accuracy calcs
top1_before_nutsf = 100. * correct1_before_nutsf / total_nutsf
top5_before_nutsf = 100. * correct5_before_nutsf / total_nutsf
top1_after_nutsf = 100. * correct1_after_nutsf / total_nutsf
top5_after_nutsf = 100. * correct5_after_nutsf / total_nutsf

improvement1nf = top1_after_nutsf - top1_before_nutsf
improvement5nf = top5_after_nutsf - top5_before_nutsf

# reset for use for spices
olm4.classifier.weight.data = og_classifier_weightf.clone()
if og_classifier_biasf is not None:
    olm4.classifier.bias.data = og_classifier_biasf.clone()
olm4.piriform.w.data = og_piriform_wf.clone()


# Absolute-readings model (olm3), online spices set.
# For each batch: (1) score top-1 / top-5 with the current weights,
# (2) run n_adapt_steps Adam steps on olm3.classifier only, each followed by a
# Hebbian update of olm3.piriform, (3) re-score the SAME batch.
# Weights are not reset between batches, so a batch's "before" score already
# reflects adaptation on every earlier batch; the classifier and piriform
# weights are restored once, after the loop.
# Hit counters, summed over the whole loader.
correct1_before_spices = 0
correct5_before_spices = 0

correct1_after_spices = 0
correct5_after_spices = 0
total_spices = 0

# One optimizer over the classifier head only, created once (outside the loop)
# so Adam's moment estimates carry over from batch to batch.
adapt_optimizers = torch.optim.Adam(olm3.classifier.parameters(), lr = adapt_lr)

for x, y in online_spices_loadera:
  x, y = x.to(device), y.to(device)

  # Before adaptation: eval-mode, no-grad scoring.
  # topk(5) assumes the classifier has at least 5 output classes.
  olm3.eval()
  with torch.no_grad():
    logits, bulb, pc = olm3(x)
    _, topk = logits.topk(5, dim = 1)
    # y is compared against top-k class indices, so it must hold integer class ids.
    correct1_before_spices += (topk[:,0] == y).sum().item()
    correct5_before_spices += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  # Online adaptation using this batch's labels.
  olm3.train()
  for step in range(n_adapt_steps):
    # NOTE(review): zero_grad() only clears the classifier's grads; backward()
    # may also fill .grad on other olm3 parameters, which then accumulate unused.
    adapt_optimizers.zero_grad()
    logits, bulb, pc = olm3(x)
    loss = criterion(logits, y)
    loss.backward()
    adapt_optimizers.step()

    # Hebbian update of the piriform weights, driven by the bulb / pc outputs of
    # this step's forward pass (computed before the optimizer step above).
    with torch.no_grad():
      olm3.piriform.hebbian_update(bulb, pc, lr = 1e-5)

  # After adaptation: re-score the same batch.
  olm3.eval()
  with torch.no_grad():
    logits, _ , _ = olm3(x)
    _, topk = logits.topk(5, dim = 1)
    correct1_after_spices += (topk[:,0] == y).sum().item()
    correct5_after_spices += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  total_spices += y.size(0)

# Hit counts -> percentages (ZeroDivisionError if the loader was empty).
top1_before_spices = 100. * correct1_before_spices / total_spices
top5_before_spices = 100. * correct5_before_spices / total_spices
top1_after_spices = 100. * correct1_after_spices / total_spices
top5_after_spices = 100. * correct5_after_spices / total_spices

# Percentage-point change produced by adaptation.
improvement1sa = top1_after_spices - top1_before_spices
improvement5sa = top5_after_spices - top5_before_spices

# Restore olm3's classifier and piriform weights from the og_* snapshots
# (captured outside this view, presumably before any online adaptation).
olm3.classifier.weight.data = og_classifier_weight.clone()
if og_classifier_bias is not None:
    olm3.classifier.bias.data = og_classifier_bias.clone()
olm3.piriform.w.data = og_piriform_w.clone()




# FOTD model (olm4), online spices set.
# For each batch: (1) score top-1 / top-5 with the current weights,
# (2) run n_adapt_stepss Adam steps on olm4.classifier only, each followed by a
# Hebbian update of olm4.piriform, (3) re-score the SAME batch.
# Weights are not reset between batches, so a batch's "before" score already
# reflects adaptation on every earlier batch; the classifier and piriform
# weights are restored once, after the loop.
correct1_before_spicesf = 0
correct5_before_spicesf = 0

correct1_after_spicesf = 0
correct5_after_spicesf = 0
total_spicesf = 0

# Adaptation hyper-parameters specific to this run.
adapt_lrs = 1e-5
n_adapt_stepss = 5
# Use the spices-specific learning rate defined just above. Previously this
# passed adapt_lrf (the FOTD/nuts rate), leaving adapt_lrs unused.
# Created once, outside the loop, so Adam's state carries across batches.
adapt_optimizers = torch.optim.Adam(olm4.classifier.parameters(), lr = adapt_lrs)

for x, y in online_spices_loaderf:
  x, y = x.to(device), y.to(device)

  # Before adaptation: eval-mode, no-grad scoring.
  # topk(5) assumes the classifier has at least 5 output classes.
  olm4.eval()
  with torch.no_grad():
    logits, bulb, pc = olm4(x)
    _, topk = logits.topk(5, dim = 1)
    # y is compared against top-k class indices, so it must hold integer class ids.
    correct1_before_spicesf += (topk[:,0] == y).sum().item()
    correct5_before_spicesf += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  # Online adaptation using this batch's labels.
  olm4.train()
  for step in range(n_adapt_stepss):
    adapt_optimizers.zero_grad()
    logits, bulb, pc = olm4(x)
    loss = criterion(logits, y)
    loss.backward()
    adapt_optimizers.step()

    # Hebbian update of the piriform weights, driven by the bulb / pc outputs of
    # this step's forward pass (computed before the optimizer step above).
    with torch.no_grad():
      olm4.piriform.hebbian_update(bulb, pc, lr = 1e-5)

  # After adaptation: re-score the same batch.
  olm4.eval()
  with torch.no_grad():
    logits, _ , _ = olm4(x)
    _, topk = logits.topk(5, dim = 1)
    correct1_after_spicesf += (topk[:,0] == y).sum().item()
    correct5_after_spicesf += (topk == y.unsqueeze(1)).any(dim = 1).sum().item()

  total_spicesf += y.size(0)

# Hit counts -> percentages (ZeroDivisionError if the loader was empty).
top1_before_spicesf = 100. * correct1_before_spicesf / total_spicesf
top5_before_spicesf = 100. * correct5_before_spicesf / total_spicesf
top1_after_spicesf = 100. * correct1_after_spicesf / total_spicesf
top5_after_spicesf = 100. * correct5_after_spicesf / total_spicesf

# Percentage-point change produced by adaptation.
improvement1sf = top1_after_spicesf - top1_before_spicesf
improvement5sf = top5_after_spicesf - top5_before_spicesf

# Restore olm4's classifier and piriform weights from the og_*f snapshots.
olm4.classifier.weight.data = og_classifier_weightf.clone()
if og_classifier_biasf is not None:
    olm4.classifier.bias.data = og_classifier_biasf.clone()
olm4.piriform.w.data = og_piriform_wf.clone()



Print results.

In [None]:
def _print_online_scores(header, t1_before, t1_after, gain1, t5_before, t5_after, gain5):
    """Print one model's top-1 / top-5 accuracy before and after online adaptation."""
    print(header)
    print(f"  Top-1 Accuracy Before: {t1_before:.2f}%")
    print(f"  Top-1 Accuracy After:  {t1_after:.2f}%")
    print(f"  Top-1 Improvement:     {gain1:.2f}%")
    print(f"  Top-5 Accuracy Before: {t5_before:.2f}%")
    print(f"  Top-5 Accuracy After:  {t5_after:.2f}%")
    print(f"  Top-5 Improvement:     {gain5:.2f}%")


# Online (test-time adaptation) results for the nuts set.
print("online testing - nuts")
_print_online_scores(
    "\nAbsolute Readings Model:",
    top1_before_nuts, top1_after_nuts, improvement1na,
    top5_before_nuts, top5_after_nuts, improvement5na,
)
_print_online_scores(
    "\nFOTD Model:",
    top1_before_nutsf, top1_after_nutsf, improvement1nf,
    top5_before_nutsf, top5_after_nutsf, improvement5nf,
)

# Online (test-time adaptation) results for the spices set.
print("online testing - spices")
_print_online_scores(
    "\nAbsolute Readings Model:",
    top1_before_spices, top1_after_spices, improvement1sa,
    top5_before_spices, top5_after_spices, improvement5sa,
)
_print_online_scores(
    "\nFOTD Model:",
    top1_before_spicesf, top1_after_spicesf, improvement1sf,
    top5_before_spicesf, top5_after_spicesf, improvement5sf,
)


# Offline test set, scored per ingredient category for both offline-trained
# models (olm1: absolute readings, olm2: FOTD).
print("offline testing")

# --- absolute-readings model ---
print("\nAbsolute Readings Model:")
test_cat_results_abs = evaluate_by_category(
    olm1,
    offline_testing_data,
    offline_testing_labels,
    device,
    ingredient_to_category,
    label_to_odor,
    mean=train_mean,
    std=train_std,
    use_fotd=False,
    # NOTE(review): this call passes test_dataset.max_len while the FOTD call
    # below passes train_datasetf.max_len -- confirm which length olm1 expects.
    max_len=test_dataset.max_len,
    k=5,
)
for category in sorted(test_cat_results_abs):
    results = test_cat_results_abs[category]
    print(
        f"\n{category}:\n"
        f" Samples: {results['n_samples']}\n"
        f"Top-1 Accuracy: {results['top1_acc']:.2f}%\n"
        f"Top-5 Accuracy: {results['top5_acc']:.2f}%"
    )

# --- FOTD model ---
print("\nFOTD Model:")
test_cat_results_fotd = evaluate_by_category(
    olm2,
    offline_testing_data,
    offline_testing_labels,
    device,
    ingredient_to_category,
    label_to_odor,
    mean=train_mean,
    std=train_std,
    use_fotd=True,
    p=25,
    max_len=train_datasetf.max_len,
    k=5,
)
for category in sorted(test_cat_results_fotd):
    results = test_cat_results_fotd[category]
    print(
        f"\n{category}:\n"
        f" Samples: {results['n_samples']}\n"
        f" Top-1 Accuracy: {results['top1_acc']:.2f}%\n"
        f" Top-5 Accuracy: {results['top5_acc']:.2f}%"
    )



Visualization

In [None]:
# One 3x4 summary figure comparing the absolute-readings and FOTD models:
# row 1 = per-epoch training/testing curves, row 2 = online adaptation
# before/after, row 3 = per-category offline accuracy and sample counts.


def _plot_training_curves(slot, metric_key, y_label, title):
    """Overlay one per-epoch metric from results1 (Absolute) and results2 (FOTD)."""
    plt.subplot(3, 4, slot)
    plt.plot(results1[metric_key], label='Absolute')
    plt.plot(results2[metric_key], label='FOTD')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)


def _plot_side_by_side(slot, tick_labels, abs_values, fotd_values,
                       x_label, y_label, title, tilt=False):
    """Draw Absolute vs FOTD as paired bars, one pair per tick label."""
    plt.subplot(3, 4, slot)
    centers = np.arange(len(tick_labels))
    bar_width = 0.35
    plt.bar(centers - bar_width/2, abs_values, bar_width, label='Absolute', alpha=0.8)
    plt.bar(centers + bar_width/2, fotd_values, bar_width, label='FOTD', alpha=0.8)
    plt.xlabel(x_label, fontsize=12)
    plt.ylabel(y_label, fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    if tilt:
        plt.xticks(centers, tick_labels, rotation=45, ha='right')
    else:
        plt.xticks(centers, tick_labels)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3, axis='y')


plt.figure(figsize=(24, 16))

# Row 1: offline training / testing curves.
_plot_training_curves(1, 'train_losses', 'Loss', 'Training Loss Comparison')
_plot_training_curves(2, 'train_accs', 'Accuracy (%)', 'Training acc')
_plot_training_curves(3, 'test_top1accs', 'top 1 Accuracy (%)', 'Testing top 1 accuracy')
_plot_training_curves(4, 'test_top5accs', 'top 5 Accuracy (%)', 'Testing top 5 accuracy')

# Row 2: online adaptation, accuracy before vs after.
stages = ['Before', 'After']
_plot_side_by_side(5, stages,
                   [top1_before_nuts, top1_after_nuts],
                   [top1_before_nutsf, top1_after_nutsf],
                   'Adaptation Stage', 'Top1 Accuracy (%)', 'online learning - nuts')
_plot_side_by_side(6, stages,
                   [top5_before_nuts, top5_after_nuts],
                   [top5_before_nutsf, top5_after_nutsf],
                   'Adaptation Stage', 'Top5 Accuracy (%)', 'online learning - nuts')
_plot_side_by_side(7, stages,
                   [top1_before_spices, top1_after_spices],
                   [top1_before_spicesf, top1_after_spicesf],
                   'Adaptation Stage', 'Top1 Accuracy (%)', 'online learning - spices')
_plot_side_by_side(8, stages,
                   [top5_before_spices, top5_after_spices],
                   [top5_before_spicesf, top5_after_spicesf],
                   'Adaptation Stage', 'Top5 Accuracy (%)', 'online learning - spices')

# Row 3: offline accuracy per category, plus the test-set category distribution.
cat_names = sorted(test_cat_results_abs.keys())
_plot_side_by_side(9, cat_names,
                   [test_cat_results_abs[c]['top1_acc'] for c in cat_names],
                   [test_cat_results_fotd[c]['top1_acc'] for c in cat_names],
                   'Category', 'top1 Accuracy (%)', 'category performance', tilt=True)
_plot_side_by_side(10, cat_names,
                   [test_cat_results_abs[c]['top5_acc'] for c in cat_names],
                   [test_cat_results_fotd[c]['top5_acc'] for c in cat_names],
                   'Category', 'top5 Accuracy (%)', 'category performance', tilt=True)

plt.subplot(3, 4, 11)
plt.bar(cat_names,
        [test_cat_results_abs[c]['n_samples'] for c in cat_names],
        alpha=0.8, color='steelblue')
plt.xlabel('Category', fontsize=12)
plt.ylabel('Number of Samples', fontsize=12)
plt.title('Test Set Distribution', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('olm_comprehensive_results.png', dpi=300, bbox_inches='tight')
plt.show()


# Persist the four models' state_dicts to Google Drive.
# olm3 / olm4 had their classifier and piriform weights restored from the og_*
# snapshots above, so their pre-adaptation weights are what gets saved here.
# The previous path 'content/drive/MyDrive/3500final' had no leading '/', so it
# resolved relative to the working directory and the checkpoints never reached
# Drive. Use the same Drive root that the data-loading paths use.
save_dir = '/content/drive/My Drive/3500final'
os.makedirs(save_dir, exist_ok=True)
for olm_idx, olm_model in enumerate((olm1, olm2, olm3, olm4), start=1):
    ckpt_path = os.path.join(save_dir, f'olm_model{olm_idx}.pth')
    torch.save(olm_model.state_dict(), ckpt_path)
    # Report the real destination (the old message also had an unbalanced quote).
    print(f"\nModel saved as '{ckpt_path}'")

