In [None]:
import pandas as pd
import os

if not os.path.isfile('IFCB.csv.zip'):
    print("CSV data do not exist. Downloading...")
    !wget -O IFCB.csv.zip "https://unioviedo-my.sharepoint.com/:u:/g/personal/gonzalezgpablo_uniovi_es/EfsVLhFsYJpPjO0KZlpWUq0BU6LaqJ989Re4XzatS9aG4Q?download=1"

data = pd.read_csv('IFCB.csv.zip',compression='infer', header=0,sep=',',quotechar='"')

#Compute sample and year information
data['year'] = data['Sample'].str[6:10].astype(str) #Compute the year
samples=data.groupby('Sample').first()
samples = samples[["year"]]
print(data)

In [None]:
from tqdm import tqdm
import numpy as np

# Enable tqdm's pandas integration.
tqdm.pandas()

# Label column to use; "AutoClass" means 51 classes.
classcolumn = "AutoClass"
# Year-based split: samples from these years go to training / test respectively.
yearstraining = ['2006', '2007']
yearstest = ['2008']

# Sample names (index of `samples`) whose year falls in each split.
is_training_year = samples['year'].isin(yearstraining)
is_test_year = samples['year'].isin(yearstest)
samplestraining = samples.loc[is_training_year].index.tolist()
samplestest = samples.loc[is_test_year].index.tolist()

# Sorted array of the distinct labels found in the chosen class column.
classes = np.sort(np.unique(data[classcolumn]))

In [None]:
import torchvision.transforms as T
from h5ifcbdataset import H5IFCBDataset
import torchvision

# Directory holding one HDF5 file per sample.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
hdf5_files_path = "/media/nas/pgonzalez/IFCB_HDF5/output/"

# Files to load: one "<sample>.hdf5" per sample name in each split
filestraining = [hdf5_files_path+s+'.hdf5' for s in samplestraining]
filestest = [hdf5_files_path+s+'.hdf5' for s in samplestest]

# Define transformations
# Training: resize, random resized crop to 64 and random horizontal flip.
train_transform = T.Compose([
  T.Resize(size=64),
  T.RandomResizedCrop(size=64),
  T.RandomHorizontalFlip(),
  T.ToTensor()
])

# Validation: deterministic resize + center crop to 64.
val_transform = T.Compose([
  T.Resize(size=64),
  T.CenterCrop(size=64),
  T.ToTensor()
])


def _load_or_build_dataset(files, cache_path, transform):
    """Return an H5IFCBDataset, loading it from `cache_path` when that file exists.

    Args:
        files: HDF5 file paths used to build the dataset when no cache exists.
        cache_path: File the dataset is loaded from, or saved to after building.
        transform: Transform handed to the dataset.

    Returns:
        The loaded or freshly built H5IFCBDataset.
    """
    # NOTE(review): trainingset=False is passed for every split, exactly as in
    # the original code; confirm against H5IFCBDataset that this is intended.
    if os.path.isfile(cache_path):
        # Same construction as the original load path: empty file list, then load().
        dset = H5IFCBDataset([],classes,classattribute=classcolumn, verbose=1,trainingset=False,transform=transform)
        dset.load(cache_path)
    else:
        dset = H5IFCBDataset(files,classes,classattribute=classcolumn, verbose=1,trainingset=False,transform=transform)
        dset.save(cache_path)
    return dset


# Each cache file is checked on its own. The previous combined check
# (`not A and not B`) took the load branch when only one file existed and then
# tried to load the missing one.
train_dset = _load_or_build_dataset(filestraining, "training.pkl", train_transform)
val_dset = _load_or_build_dataset(filestest, "validation.pkl", val_transform)

In [None]:
import torch
# Fraction of the training set that is kept as labeled data.
percentage = 0.1
# From here on `percentage` holds the labeled example COUNT, not the fraction.
percentage = round(len(train_dset) * percentage)

# Seeded generator so the labeled/unlabeled partition is the same on every run.
split_sizes = [percentage, len(train_dset) - percentage]
split_rng = torch.Generator().manual_seed(42)
train_labeled, train_unlabeled = torch.utils.data.random_split(
    train_dset, split_sizes, generator=split_rng
)

# Sizes of the labeled, unlabeled and validation sets, one per line.
for subset in (train_labeled, train_unlabeled, val_dset):
    print(len(subset))

In [None]:
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class IFCBDataset(pl.LightningDataModule):
    """LightningDataModule that serves already-built training and validation datasets.

    Args:
        batch_size: Batch size used by both dataloaders.
        train_dset: Dataset used for training (shuffled on every epoch).
        val_dset: Dataset used for validation (not shuffled).
        num_workers: Worker processes per dataloader. Defaults to 8.
        num_classes: Number of target classes, stored as an attribute.
            Defaults to 51.
    """

    def __init__(self, batch_size, train_dset, val_dset, num_workers=8, num_classes=51):
        super().__init__()
        self.train_dset = train_dset
        self.val_dset = val_dset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_classes = num_classes

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(
            self.train_dset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the validation dataloader (sequential order)."""
        return DataLoader(
            self.val_dset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

# Use only a subset of the validation set to speed up evaluation.
# The subset size is capped at the dataset length so that random_split never
# receives a negative length when fewer than 20000 examples are available.
val_subset_size = min(20000, len(val_dset))
reduced_val_dset, _ = torch.utils.data.random_split(
    val_dset,
    [val_subset_size, len(val_dset) - val_subset_size],
    generator=torch.Generator().manual_seed(42),  # fixed seed: same subset every run
)
print(len(reduced_val_dset))

# Data module for fine-tuning: labeled training subset + reduced validation set,
# with a batch size of 512.
finetuning_ds = IFCBDataset(512,train_labeled,reduced_val_dset)

In [None]:
import torchvision.models as models
from torch.nn.functional import cross_entropy
import torchmetrics
import pytorch_lightning as pl
import torch.nn as nn

from torch.optim import Adam

class IFCBFineTuning(pl.LightningModule):
    """ResNet-18 classifier fine-tuned with cross-entropy.

    The backbone starts from torchvision's pretrained ResNet-18. When
    `ckpt_path` is given, the weights stored under the checkpoint key
    'resnet18_parameters' are loaded on top (non-strictly). The final fully
    connected layer is then replaced by a fresh `num_classes`-way linear layer.
    No parameter is frozen.

    Args:
        num_classes: Output size of the new classification layer. Defaults to 51.
        ckpt_path: Path of the checkpoint holding the SimCLR-learned ResNet-18
            weights. Pass None to fine-tune without loading them. Defaults to
            'model.pth'.
        lr: Learning rate for the Adam optimizer. Defaults to 1e-4.
    """

    def __init__(self, num_classes=51, ckpt_path='model.pth', lr=1e-4):
        super().__init__()
        self.lr = lr

        # init a pretrained resnet
        self.model = models.resnet18(pretrained=True)

        # Load the weights learned with SimCLR. Use ckpt_path=None to do plain
        # fine-tuning without them.
        if ckpt_path is not None:
            # map_location='cpu' lets a checkpoint saved on a GPU be read on a
            # machine without CUDA; load_state_dict copies the values into the
            # model's own parameters afterwards.
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(ckpt_path, map_location='cpu')
            # strict=False: keys that do not match the model are ignored.
            self.model.load_state_dict(ckpt['resnet18_parameters'],strict=False)

        # Replace the classification head with a new, randomly initialised layer.
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch.

        The batch is unpacked as (inputs, targets, third item); the third item
        is not used here.
        """
        # return the loss given a batch: this has a computational graph attached to it: optimization
        x, y,_ = batch
        preds = self.model(x)
        loss = cross_entropy(preds, y)
        self.log('train_loss', loss)  # lightning detaches your loss graph and uses its value
        return loss

    #validation step
    def validation_step(self, batch, batch_idx):
        """Update and log the validation accuracy for one validation batch."""
        x, y, _= batch
        y_hat = self.model(x)
        # Predictions and targets are moved to the CPU before updating the metric.
        self.valid_acc(y_hat.cpu(), y.cpu())
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True,prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters."""
        optimizer = Adam(self.model.parameters(), lr=self.lr)
        return optimizer

# Build the LightningModule that will be fine-tuned.
finetuning_model = IFCBFineTuning()

In [None]:
# Use up to two GPUs, but never more than are actually present: requesting 2
# on a single-GPU machine would fail. Falls back to CPU (0) without CUDA.
gpus = min(2, torch.cuda.device_count()) if torch.cuda.is_available() else 0

# Train for 50 epochs with the "dp" strategy.
trainer = pl.Trainer(strategy="dp",gpus=gpus, max_epochs=50)
trainer.fit(finetuning_model, finetuning_ds)