In [5]:
import pandas as pd
import plotly.express as px

In [26]:
folder = "/srv/scratch/AMR/Reduced_genotype"
file = f"{folder}/Campylobacter_jejuni_reduced_genotype.tsv"
x = pd.read_csv(file, sep="\t", index_col=0)
y = pd.read_csv("/srv/scratch/AMR/IR_phenotype/Campylobacter_jejuni/phenotype.txt", sep="\t", index_col=0)
y = y.loc[x.index]

In [30]:
y

Unnamed: 0,TET
01M0ORJA,1
04YTLP0K,1
07TMJS8Q,1
0EM84OF9,0
0ITGCEEM,0
...,...
ZXY2SJZL,1
ZYFBKN1V,0
ZYI24BNW,1
ZYPNSHM4,1


In [14]:
import torch
import lightning as L



IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



In [52]:
class MyModel(L.LightningModule):
    def __init__(self, n_feats:int, dropout:float = 0.5):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(n_feats, 128), # layer 1
            torch.nn.ReLU(), # activation function
            torch.nn.Dropout(dropout), # dropout for regularization
            torch.nn.Linear(128, 64), 
            torch.nn.ReLU(), 
            torch.nn.Dropout(dropout),
            torch.nn.Linear(64, 1)
        )

    def forward(self, x:torch.Tensor):
        return self.mlp(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = torch.nn.functional.binary_cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [53]:
x_train = x.sample(frac=0.8, random_state=42)
y_train = y.loc[x_train.index]
x_val = x.drop(x_train.index)
y_val = y.loc[x_val.index]
train_dataloader = torch.utils.data.DataLoader(
    list(
        zip(
            torch.tensor(x_train.values, dtype=torch.float32),
            torch.tensor(y_train.values, dtype=torch.float32),
        )
    ),
    batch_size=32,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    list(
        zip(
            torch.tensor(x_val.values, dtype=torch.float32),
            torch.tensor(y_val.values, dtype=torch.int32),
        )
    ),
    batch_size=32,
    shuffle=False,
)

In [54]:
batch = next(iter(train_dataloader))
x_, y_ = batch
y_

tensor([[0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.]])

In [55]:
model = MyModel(n_feats=x.shape[1])
trainer = L.Trainer(max_epochs=10, accelerator="cpu")
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ilya/.conda/envs/esm/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | mlp  | Sequential | 22.1 K | train
--------------------------------------------
22.1 K    Trainable params
0         Non-trainable params
22.1 K    Total params
0.089     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
