## Load Libraries

In [10]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy
from glob import glob
import pandas as pd
from PIL import Image
import torchdata.datapipes as dp
import random
from torch.utils.data.backward_compatibility import worker_init_fn
from sklearn.metrics import classification_report


In [13]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import confusion_matrix
import plotly.graph_objects as go


## Load Data

In [15]:
# Read the detection CSVs and keep only the three columns used by the
# data pipeline: slide identifier, crop path and APOE label.
train_df = pd.read_csv(
    "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train_detected_full2.csv"
).loc[:, ["Imaging_XENum", "crop_filepath", "apoe_label"]]
val_df = pd.read_csv(
    "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val_detected_full2.csv"
).loc[:, ["Imaging_XENum", "crop_filepath", "apoe_label"]]

# The DataFrame index is written as the first CSV column (pandas default);
# open_image() unpacks it as the leading, ignored field of each row.
train_df.to_csv(
    "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train.csv"
)
val_df.to_csv(
    "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val.csv"
)

## Dataloaders

In [16]:
# Side length fed to the network. The DINOv2 ViT-S/14 backbone used in this
# notebook splits the image into 14x14 patches, so height and width must be
# multiples of 14. 1024 is not; 1022 = 73 * 14 is the closest multiple below it.
IMG_SIZE = 1022

# One torchvision pipeline per split, keyed by 'train', 'val' and 'test'.
# Normalisation constants are the same for all splits.
data_transforms = {
    'train': transforms.Compose([
        # Random scale/aspect crop, always resized to a square IMG_SIZE image
        transforms.RandomResizedCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        # Resize(int) only fixes the shorter edge; the centre crop guarantees
        # a square IMG_SIZE x IMG_SIZE input (a no-op for square images).
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
}

In [18]:
def open_image(inputs):
    """Load the image referenced by one parsed CSV row.

    Args:
        inputs: a 4-field row ``(index, wsi_name, img_path, label)`` as
            produced by the CSV parser in ``build_data_pipe``; the first
            field is ignored.

    Returns:
        ``(wsi_name, img, label)`` where ``img`` is a fully loaded RGB
        ``PIL.Image`` and ``label`` has been converted with ``int``.

    Raises:
        ValueError: if the row does not have exactly four fields or the
            label cannot be converted to ``int``.
    """
    _, wsi_name, img_path, label = inputs
    # Image.open is lazy and keeps the file handle open; convert() forces the
    # pixel data to be read so the handle can be closed right away. Converting
    # to RGB also guarantees the 3 channels the Normalize transforms expect.
    with Image.open(img_path) as handle:
        img = handle.convert("RGB")
    return wsi_name, img, int(label)

def apply_train_transforms(inputs):
    """Run the 'train' transform on an ``(wsi_name, image, label)`` triple.

    The slide name is dropped; the result is ``(transformed_image, label)``.
    """
    _wsi_name, image, target = inputs
    augmented = data_transforms["train"](image)
    return augmented, target

def apply_val_transforms(inputs):
    """Run the 'val' transform on an ``(wsi_name, image, label)`` triple.

    Unlike the training variant, the slide name is kept, so the result is
    ``(wsi_name, transformed_image, label)``.
    """
    slide, image, target = inputs
    prepared = data_transforms["val"](image)
    return slide, prepared, target

def build_data_pipe(csv_file, transform , batch_size=32):
    """Assemble the torchdata pipeline that turns a CSV into collated batches.

    Args:
        csv_file: path of the CSV to read; its header line is skipped.
        transform: ``"train"`` (shuffled, train transforms, incomplete last
            batch dropped) or ``"val"`` (no shuffle, val transforms, last
            batch kept).
        batch_size: number of samples per batch.

    Returns:
        The datapipe, whose items are batches passed through
        ``torch.utils.data.default_collate``.

    Raises:
        ValueError: if ``transform`` is neither ``"train"`` nor ``"val"``.
    """
    is_train = transform == "train"

    # Rows come out as tuples like ('0', 'filename', 'filepath', 'label').
    pipe = dp.iter.FileOpener([csv_file]).parse_csv(skip_lines=1)

    # Shuffle BEFORE sharding_filter so the data are globally shuffled before
    # being split across workers. With the opposite order every worker would
    # see the same shard in every epoch and each batch would only contain
    # samples from one shard. (Not relevant for sources that are already
    # sharded per process, where no sharding_filter is needed at all.)
    if is_train:
        pipe = pipe.shuffle()
    pipe = pipe.sharding_filter()

    pipe = pipe.map(open_image)

    if is_train:
        sample_fn, drop_incomplete = apply_train_transforms, True
    elif transform == "val":
        sample_fn, drop_incomplete = apply_val_transforms, False
    else:
        raise ValueError("Invalid transform argument.")

    pipe = pipe.map(sample_fn)
    pipe = pipe.batch(batch_size=batch_size, drop_last=drop_incomplete)
    return pipe.map(torch.utils.data.default_collate)

def dataset_size(csv_file):
    """Return the number of data rows (header excluded) in ``csv_file``."""
    return pd.read_csv(csv_file).shape[0]

In [32]:
# Number of samples per batch produced by the datapipes.
batch_size = 1

# Slim CSVs written in the "Load Data" section. TRAIN_CSV previously pointed
# at val.csv, which made the model train on the validation split.
TRAIN_CSV = "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train.csv"
VAL_CSV = "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val.csv"

In [33]:
# One datapipe per split; the split name selects shuffling and transforms.
train_dp = build_data_pipe(csv_file=TRAIN_CSV, transform="train", batch_size=batch_size)
val_dp = build_data_pipe(csv_file=VAL_CSV, transform="val", batch_size=batch_size)

In [34]:
# Row counts of the split CSVs, also exposed as a dict keyed by phase name.
train_datasize, val_datasize = (dataset_size(path) for path in (TRAIN_CSV, VAL_CSV))
dataset_sizes = dict(train=train_datasize, val=val_datasize)

In [35]:
# Wrap the datapipes in DataLoaders with 4 worker processes each; only the
# training loader requests shuffling.
# NOTE(review): the datapipes already batch and collate their samples while
# batch_size is left at the DataLoader default here; confirm that the extra
# leading dimension this presumably adds to every batch is intended.
train_loader = torch.utils.data.DataLoader(train_dp, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dp, shuffle=False, num_workers=4)

dataloaders = dict(train=train_loader, val=val_loader)

# First CUDA device when available, CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## Train Dino-v2 Model

In [36]:
# Training-run settings; "epochs" is the value passed to train_model below.
train_config = {
    "epochs": 5,
    "batch_size": 6,
    "num_classes": 3,
    "device_id": 0,
    "eval_freq": 5,
}

In [None]:
# Learning rate and momentum (same values as the SGD optimizer created below).
model_config = {"lr": 0.001, "momentum": 0.9}

# Step size and decay factor (same values as the StepLR scheduler created below).
optim_config = {"step_size": 7, "gamma": 0.1}

: 

In [29]:
def train_model( model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it together with per-epoch metrics.

    Reads the module-level ``dataloaders`` dict (keyed by phase name) and the
    module-level ``device``. Only the 'train' phase is run.

    Args:
        model: the ``nn.Module`` to optimise (already on ``device``).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per training epoch.
        num_epochs: number of passes over the training loader.

    Returns:
        ``(model, log_metrics)`` where ``log_metrics`` is a list of dicts with
        keys ``epoch``, ``phase``, ``loss`` and ``metrics`` (accuracy).
    """
    since = time.time()

    # Only a validation phase may record "best" weights. While none is run
    # this stays None and the trained weights are returned as they are; the
    # previous code snapshotted the initial weights here and restored them at
    # the end, which threw the whole training away.
    best_model_wts = None
    best_acc = 0.0
    log_metrics = list()
    phases = ['train']

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in phases:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = torch.zeros((), dtype=torch.long, device=device)
            samples_seen = 0

            for batch in dataloaders[phase]:
                # Train batches are (inputs, labels); the val pipeline also
                # carries the slide name in front, so take the last two items.
                *_, inputs, labels = batch

                # Collapse any extra leading dimension into the batch axis.
                # A bare squeeze() would also remove the batch axis whenever
                # the batch holds a single sample.
                inputs = inputs.reshape(-1, *inputs.shape[-3:]).to(device)
                labels = labels.reshape(-1).to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics
                batch_n = inputs.size(0)
                running_loss += loss.item() * batch_n
                running_corrects += torch.sum(preds == labels)
                samples_seen += batch_n

            if is_train:
                scheduler.step()

            # Average over the samples actually processed: the training pipe
            # drops its incomplete last batch, so the CSV row count can be
            # larger than what the model has seen.
            denom = max(samples_seen, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_corrects.double() / denom

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            log_metrics.append(dict(epoch=epoch,phase=phase, loss=epoch_loss, metrics=epoch_acc))

            # Remember the best validation weights (inactive while `phases`
            # contains only 'train').
            if phase == 'val' and float(epoch_acc) > best_acc:
                best_acc = float(epoch_acc)
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    hours, remainder = divmod(int(time_elapsed), 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f'Training complete in {hours}h {minutes}m {seconds}s')

    # Restore the best validation weights only if some were recorded.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    return model, log_metrics

In [31]:
# Fetch the 'dinov2_vits14' entry point of the facebookresearch/dinov2
# repository through torch.hub.
dinov2_vits14 = torch.hub.load(
    'facebookresearch/dinov2',
    'dinov2_vits14',
)

class DinoVisionTransformerClassifier(nn.Module):
    """Backbone followed by a two-layer MLP classification head.

    Args:
        backbone: feature extractor module; defaults to the module-level
            ``dinov2_vits14``. It must expose a ``norm`` submodule, which
            ``forward`` applies to its output.
        num_classes: number of output logits (default 3).
        hidden_dim: width of the hidden layer of the head (default 256).

    With no arguments the head is ``Linear(384, 256) -> ReLU -> Linear(256, 3)``
    whenever the backbone reports ``embed_dim == 384`` or has no ``embed_dim``.
    """

    def __init__(self, backbone=None, num_classes=3, hidden_dim=256):
        super(DinoVisionTransformerClassifier, self).__init__()
        self.transformer = dinov2_vits14 if backbone is None else backbone
        # Use the feature width the backbone reports; fall back to 384, the
        # value that was previously hard-coded here.
        embed_dim = getattr(self.transformer, "embed_dim", 384)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.transformer(x)
        # NOTE(review): the DINOv2 backbone presumably already applies this
        # norm layer before returning its features; confirm whether the
        # second pass is intended. Kept so existing checkpoints behave the same.
        x = self.transformer.norm(x)
        x = self.classifier(x)
        return x

# Build the classifier and move it to the selected device.
model_dino = DinoVisionTransformerClassifier().to(device)

criterion = nn.CrossEntropyLoss()

# All parameters (backbone and head) are optimised. The hyperparameters come
# from model_config (lr, momentum) instead of being repeated as literals.
optimizer_ft = optim.SGD(model_dino.parameters(), **model_config)

# Step decay of the learning rate, configured by optim_config
# (step_size epochs between decays, multiplied by gamma each time).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, **optim_config)

model_dino, log_metrics = train_model(
    model_dino,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=train_config["epochs"],
)

Using cache found in /home/vivek/.cache/torch/hub/facebookresearch_dinov2_main


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Persist both the whole module object and its state_dict in a single file.
torch.save(
    {"model": model_dino, "state": model_dino.state_dict()},
    '/gladstone/finkbeiner/steve/work/data/npsad_data/monika/Amy_plaque_Results/models/dino_model_5ep_detected.pth',
)

: 

: 