In [None]:
import pandas as pd
import os

import subprocess

# Local cache of the IFCB metadata archive and the location it is fetched from.
csv_archive = 'IFCB.csv.zip'
csv_url = "https://unioviedo-my.sharepoint.com/:u:/g/personal/gonzalezgpablo_uniovi_es/EfsVLhFsYJpPjO0KZlpWUq0BU6LaqJ989Re4XzatS9aG4Q?download=1"

# Download the archive only when it is not already cached locally.
if not os.path.isfile(csv_archive):
    print("CSV data do not exist. Downloading...")
    # Fetch into a temporary name first: `wget -O` creates the output file even
    # when the transfer fails, and an empty/truncated 'IFCB.csv.zip' would pass
    # the isfile() check on every later run and then break read_csv.
    partial_archive = csv_archive + '.part'
    try:
        # check=True raises CalledProcessError on a non-zero wget exit status.
        subprocess.run(["wget", "-O", partial_archive, csv_url], check=True)
        os.replace(partial_archive, csv_archive)
    finally:
        # The temporary file only still exists here when the download failed.
        if os.path.exists(partial_archive):
            os.remove(partial_archive)

# The compression is inferred from the '.zip' extension.
data = pd.read_csv(csv_archive, compression='infer', header=0, sep=',', quotechar='"')

#Compute sample and year information
# Characters 6..9 of the 'Sample' value are used as the year
# (assumes the sample identifier embeds a 4-digit year there -- TODO confirm).
data['year'] = data['Sample'].str[6:10].astype(str) #Compute the year
# One row per sample (first row of each group), keeping only its year.
samples=data.groupby('Sample').first()
samples = samples[["year"]]
print(data)

In [None]:
from tqdm import tqdm
import numpy as np

# Register tqdm's progress-reporting helpers on pandas objects.
tqdm.pandas()

# --- Experiment configuration -------------------------------------------
classcolumn = "AutoClass"  # AutoClass means 51 classes
yearstraining = ['2006', '2007']  # Years to consider as training
yearstest = ['2008']  # Years to consider as test

# --- Sample identifiers per split ---------------------------------------
# `samples` is indexed by sample identifier; select the index entries whose
# year belongs to each split.
samplestraining = list(samples.index[samples['year'].isin(yearstraining)])  # Samples to consider for training
samplestest = list(samples.index[samples['year'].isin(yearstest)])  # Samples to consider for testing

# --- Class vocabulary ----------------------------------------------------
# Distinct labels found in the chosen class column, in sorted order.
classes = np.unique(data[classcolumn])
classes.sort()

In [None]:
from h5ifcbdataset import H5IFCBDataset
import os

# Directory that holds one "<sample>.hdf5" file per sample.
hdf5_files_path = "/media/nas/pgonzalez/IFCB_HDF5/output/"

# Full paths of the HDF5 files belonging to each split.
filestraining = [hdf5_files_path + sample_id + '.hdf5' for sample_id in samplestraining]
filestest = [hdf5_files_path + sample_id + '.hdf5' for sample_id in samplestest]

# Reuse the pickled dataset when it is already on disk; otherwise build it
# from the HDF5 files and cache it for the next run.
if os.path.isfile('training.pkl'):
    train_dset = H5IFCBDataset(
        [], classes, classattribute=classcolumn, verbose=1, trainingset=False
    )
    train_dset.load("training.pkl")
else:
    train_dset = H5IFCBDataset(
        filestraining, classes, classattribute=classcolumn, verbose=1, trainingset=False
    )
    train_dset.save("training.pkl")


In [None]:
from lightly.models.modules.heads import SimCLRProjectionHead
from lightly.loss import NTXentLoss
import pytorch_lightning as pl
import torch.nn as nn

class SimCLRModel(pl.LightningModule):
    """SimCLR-style self-supervised model: a ResNet-18 backbone plus a projection head.

    Args:
        max_epochs: Period (in epochs) of the cosine learning-rate schedule.
            When left as ``None`` the value is read from the attached
            trainer's ``max_epochs`` when the optimizers are configured, so
            the class does not depend on a notebook-level global.
    """

    def __init__(self, max_epochs=None):
        super().__init__()
        self.max_epochs = max_epochs

        # create a ResNet backbone and remove the classification head
        # (every child module of the ResNet except the last one is kept)
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # The projection head consumes features of the same width the removed
        # fully-connected layer expected as input.
        hidden_dim = resnet.fc.in_features
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, 128)

        self.criterion = NTXentLoss()

    def forward(self, x):
        """Return the projection-head output for the batch ``x``."""
        # Flatten everything after the batch dimension before projecting.
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        """Compute and log the contrastive loss between the two views of a batch.

        ``batch`` is unpacked as ``((x0, x1), _, _)``: a pair of views followed
        by two entries that are ignored here.
        """
        (x0, x1), _, _ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer with a cosine-annealing learning-rate schedule."""
        optim = torch.optim.SGD(
            self.parameters(), lr=6e-2, momentum=0.9, weight_decay=5e-4
        )
        # Prefer the explicitly configured schedule length; otherwise follow
        # the number of epochs the attached trainer was asked to run.
        schedule_epochs = (
            self.max_epochs if self.max_epochs is not None else self.trainer.max_epochs
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, schedule_epochs
        )
        return [optim], [scheduler]

In [None]:
import torch
# Fraction of the training set that is treated as labeled.
percentage = 0.1

# Convert the fraction into an absolute number of items.
n_total = len(train_dset)
percentage = round(percentage * n_total)

# A dedicated, fixed-seed generator keeps the split reproducible.
split_generator = torch.Generator().manual_seed(42)
train_labeled, train_unlabeled = torch.utils.data.random_split(
    train_dset,
    [percentage, n_total - percentage],
    generator=split_generator,
)

for subset in (train_labeled, train_unlabeled):
    print(len(subset))

In [None]:
# Reproducibility
seed = 1

# Data loading
num_workers = 8   # worker processes used by the DataLoader
batch_size = 256
input_size = 64   # passed as `input_size` to the SimCLR collate function

# Training length
max_epochs = 100

# Seed the random number generators through PyTorch Lightning.
pl.seed_everything(seed)

In [None]:
import lightly
import torchvision

# Use up to two GPUs, but never more than the machine actually exposes:
# unconditionally requesting 2 fails on a single-GPU host. device_count()
# is 0 when CUDA is unavailable, which keeps the original CPU fallback.
gpus = min(2, torch.cuda.device_count())

# Collate function that produces the two augmented views of each image.
# Color jitter and random grayscale are switched off (probability 0.0);
# `vf_prob` and `rr_prob` are each set to 0.5.
collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=input_size,
    vf_prob=0.5,
    rr_prob=0.5,
    cj_prob=0.0,
    random_gray_scale=0.0
)

# Self-supervised pre-training uses only the unlabeled part of the split.
# drop_last=True discards the final incomplete batch.
dataloader_train_simclr = torch.utils.data.DataLoader(
    train_unlabeled,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

model = SimCLRModel()
trainer = pl.Trainer(
    strategy="dp",max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100
)
trainer.fit(model, dataloader_train_simclr)

# Keep only the backbone; the projection head is not saved.
pretrained_resnet_backbone = model.backbone

# you can also store the backbone and use it in another code
state_dict = {
    'resnet18_parameters': pretrained_resnet_backbone.state_dict()
}
torch.save(state_dict, 'model.pth')