In [1]:
import numpy as np
import pandas as pd

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch import optim
from torch.nn import functional as F
import lightning as L

In [30]:
# Fix the global random seed through Lightning so stochastic steps are repeatable.
L.seed_everything(42)

Seed set to 42


42

In [4]:
# Load the long-format abundance table and the per-sample group labels.
abundance = pd.read_csv("../../results/data/prepared/processed_abundance.csv")
groups = pd.read_csv("../../results/data/prepared/groups.csv")

# Reshape to one row per sample / one column per glycan, then log2-transform.
abundance = abundance.pivot(index="sample", columns="glycan", values="value")
abundance = pd.DataFrame(np.log2(abundance.values), columns=abundance.columns, index=abundance.index)
groups = groups.set_index("sample")
data = pd.merge(abundance, groups, left_index=True, right_index=True, how="left")
# NOTE(review): with how="left", a sample missing from groups.csv gets a NaN
# group; it survives the QC filter and is labeled 0 below. Confirm every
# sample has a group.
# Take an explicit copy: the filtered frame is assigned into below, and writing
# into a slice of `data` would otherwise trigger pandas' SettingWithCopyWarning.
data = data[data["group"] != "QC"].copy()

# Keep the original (non-binarized) group labels; the splits stratify on them.
groups4 = data["group"].copy()

# Binary target: 1 for group "C", 0 for every other group.
data["group"] = (data["group"] == "C").astype(int)

In [5]:
# Preview the first rows: one row per sample, log2 glycan abundances plus the binary `group` label.
data.head()

Unnamed: 0_level_0,H3N3,H3N3F1,H3N4,H3N4F1,H3N5,H3N5F1,H4N2,H4N3,H4N3F1,H4N3F1S1,...,H7N6F1S2,H7N6F1S3,H7N6F1S4,H7N6S1,H7N6S2,H7N6S3,H7N6S4,H8N2,H9N2,group
sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
D1,10.74885,10.895486,11.347862,17.077142,13.065964,14.451495,8.658876,11.412047,10.863957,11.886809,...,11.548488,11.774836,12.04017,11.928917,12.229021,12.175362,12.302443,14.225364,14.745809,1
D10,11.043689,11.023918,11.717196,15.696562,12.155633,13.548614,9.236748,11.71613,11.134252,12.062038,...,9.766474,10.82642,11.99995,8.236748,12.389489,12.908261,13.951431,14.755339,15.170621,0
D100,11.871526,11.800639,11.713925,15.496599,11.220893,13.002675,9.603057,12.200724,11.904027,12.653039,...,10.391143,10.658086,11.017888,10.268883,12.791517,13.311963,13.744936,14.973595,14.412574,0
D101,12.11613,11.739967,11.665591,16.088801,11.387755,13.677937,10.003786,12.35799,11.728729,8.993963,...,10.641457,10.712676,11.032578,12.252853,13.273095,13.383218,13.517158,14.988095,15.355569,0
D102,11.589243,11.763432,11.139805,16.338681,11.58031,13.171035,9.309983,12.107133,11.750405,12.192766,...,10.041073,10.018495,10.770794,12.140416,12.94609,13.394846,14.028628,14.539586,14.250261,0


In [6]:
# (number of samples, number of glycan columns + 1 label column)
data.shape

(720, 63)

In [7]:
from sklearn.model_selection import train_test_split

# Step 1: hold out the test set, stratified on the original group labels.
total_train_data, test_data, total_train_groups4, _ = train_test_split(
    data,
    groups4,
    test_size=128,
    stratify=groups4,
    shuffle=True,
    random_state=42,
)
# Step 2: split what remains into training and validation sets, again
# stratified on the original group labels of the remaining samples.
train_data, val_data = train_test_split(
    total_train_data,
    test_size=128,
    stratify=total_train_groups4,
    shuffle=True,
    random_state=42,
)

# Report how many samples ended up in each split.
print("Training size:", len(train_data))
print("Validation size:", len(val_data))
print("Testing size:", len(test_data))

Training size: 464
Validation size: 128
Testing size: 128


In [8]:
# Normalization statistics are taken from the training split only. Each is a
# single scalar pooled over all feature columns (the label column is excluded).
DATA_MEAN = train_data.drop(columns="group").to_numpy().mean()
DATA_STD = train_data.drop(columns="group").to_numpy().std()

print("Mean:", DATA_MEAN)
print("Std:", DATA_STD)

Mean: 13.630876642212709
Std: 2.3228725257250544


In [9]:
class GlycanDataset(Dataset):
    """Map-style dataset over a samples-by-columns DataFrame.

    Every column except the last is treated as a feature; the last column is
    the target. Both are returned as float32 tensors.

    Args:
        data: DataFrame with one row per sample and the label in the last column.
        transform: Optional callable applied to the feature tensor only.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the row at position ``idx``."""
        row = self.data.iloc[idx]
        # The row is indexed by column *labels*, so positional access has to go
        # through .iloc: plain `row[-1]` on a non-integer index is deprecated
        # in pandas 2.x and will be interpreted as a label lookup.
        x = torch.tensor(row.iloc[:-1].to_numpy(), dtype=torch.float32)
        if self.transform is not None:
            x = self.transform(x)
        y = torch.tensor(row.iloc[-1], dtype=torch.float32)
        return x, y

In [10]:
def normalizer(x):
    """Standardize a feature tensor using the training-set DATA_MEAN and DATA_STD."""
    return (x - DATA_MEAN) / DATA_STD


# Wrap each split in a dataset that applies the same normalization.
train_set, val_set, test_set = (
    GlycanDataset(split, transform=normalizer) for split in (train_data, val_data, test_data)
)

In [11]:
# Spot-check one item: a (normalized feature tensor, label tensor) pair.
train_set[0]

(tensor([-0.8260, -0.8260, -0.8260,  0.8291, -0.8260,  0.0692, -0.8260, -0.8260,
         -0.8260, -0.8260,  0.4034, -0.3955,  1.1168,  0.6862,  0.0687,  0.2070,
          0.5013,  0.1339, -0.8260,  0.1304, -0.8260, -0.8260,  0.5343, -0.0102,
          0.7880,  1.3160,  1.5491, -0.8260, -0.8260,  1.2886,  2.5426, -0.8260,
          0.4412,  1.1227,  1.0018,  0.9241, -0.8260,  0.5234, -0.8260, -0.8260,
          0.3863, -0.8260, -0.8260, -0.0209,  0.7313,  1.3888,  0.1713,  0.4232,
          0.2328,  1.0858, -0.8260, -0.8260,  0.1319, -0.8260, -0.8260, -0.8260,
         -0.8260, -0.8260, -0.8260, -0.8260,  0.4597,  0.4394]),
 tensor(1.))

In [12]:
# Training batches are reshuffled each epoch and the incomplete last batch is dropped.
train_loader = DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)
# Evaluation loaders keep a fixed order and keep every sample.
val_loader = DataLoader(
    val_set,
    batch_size=32,
    shuffle=False,
    drop_last=False,
)
test_loader = DataLoader(
    test_set,
    batch_size=32,
    shuffle=False,
    drop_last=False,
)

In [20]:
class GlycanModel(L.LightningModule):
    """Fully connected binary classifier trained with BCE-with-logits loss.

    The network is a stack of ``Linear -> ReLU -> Dropout`` blocks followed by
    a final ``Linear(layers[-1], 1)`` that emits one logit per sample.

    Args:
        layers: Layer widths. ``layers[0]`` is the number of input features;
            each consecutive pair ``(layers[i], layers[i + 1])`` defines one
            hidden block.
        dropouts: Dropout probability for each hidden block; must contain
            ``len(layers) - 1`` entries.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, layers, dropouts, lr=1e-3, weight_decay=1e-5):
        super().__init__()
        assert len(layers) == len(dropouts) + 1, "need exactly one dropout rate per hidden block"
        self.save_hyperparameters()
        self.net = self._create_net()
        self.loss = nn.BCEWithLogitsLoss()
        # Size the example input from the configured input width rather than a
        # hard-coded feature count, so the model summary works for any `layers`.
        self.example_input_array = torch.zeros(16, layers[0])

    def _create_net(self):
        """Build the MLP described by ``hparams.layers`` / ``hparams.dropouts``."""
        widths = self.hparams.layers
        modules = []
        for i, (in_features, out_features) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.ReLU())
            modules.append(nn.Dropout(self.hparams.dropouts[i]))
        # Output head: a single logit per sample.
        modules.append(nn.Linear(widths[-1], 1))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)``."""
        return self.net(x)

    def configure_optimizers(self):
        """Adam with the learning rate scaled by 0.1 at milestones 100 and 150."""
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def _get_loss_and_acc(self, batch, batch_idx):
        """Compute the BCE loss and the accuracy at a 0.5 probability threshold."""
        x, y = batch
        # Drop only the trailing logit dimension. A bare squeeze() would also
        # remove the batch dimension when the batch holds a single sample,
        # leaving a 0-d tensor that no longer matches the shape of `y`.
        y_hat = self(x).squeeze(-1)
        loss = self.loss(y_hat, y)
        probs = torch.sigmoid(y_hat)
        preds = (probs > 0.5).float()
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log epoch-level training loss/accuracy and return the loss."""
        loss, acc = self._get_loss_and_acc(batch, batch_idx)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy."""
        loss, acc = self._get_loss_and_acc(batch, batch_idx)
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy."""
        loss, acc = self._get_loss_and_acc(batch, batch_idx)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

In [39]:
model = GlycanModel(layers=[abundance.shape[1], 128, 64, 16], dropouts=[0.3, 0.3, 0.0], lr=1e-3, weight_decay=1e-5)
# Stop once val_loss has not improved for 10 consecutive validation runs.
early_stop_callback = L.pytorch.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
# Early stopping ends training several epochs after the best one, so also keep
# the checkpoint with the lowest val_loss; without a monitored checkpoint only
# the most recent epoch is kept.
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
logger = L.pytorch.loggers.TensorBoardLogger("lightning_logs", name="glycan")
trainer = L.Trainer(
    max_epochs=180,
    logger=logger,
    log_every_n_steps=1,
    callbacks=[early_stop_callback, checkpoint_callback],
)
trainer.fit(model, train_loader, val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type              | Params | In sizes | Out sizes
------------------------------------------------------------------
0 | net  | Sequential        | 17.4 K | [16, 62] | [16, 1]  
1 | loss | BCEWithLogitsLoss | 0      | ?        | ?        
------------------------------------------------------------------
17.4 K    Trainable params
0         Non-trainable params
17.4 K    Total params
0.070     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [41]:
# Reload the checkpoint recorded by this run's checkpoint callback instead of a
# hard-coded `version_N/epoch=..-step=..` path, which changes on every run and
# breaks Restart & Run All.
best_ckpt_path = trainer.checkpoint_callback.best_model_path
assert best_ckpt_path, "no checkpoint was saved by the trainer; run trainer.fit first"
model = GlycanModel.load_from_checkpoint(best_ckpt_path)
# Evaluate the reloaded model on the validation and held-out test loaders.
val_results = trainer.test(model, val_loader)
test_results = trainer.test(model, test_loader)

/Users/fubin/miniforge3/envs/torch/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]