## Import Libraries

In [18]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy
from glob import glob
import pandas as pd
from PIL import Image
import torchdata.datapipes as dp
import random
from torch.utils.data.backward_compatibility import worker_init_fn
from sklearn.metrics import classification_report
# import wandb
# import plotly.graph_objects as go

In [19]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import confusion_matrix
import plotly.graph_objects as go


## Dataloaders

In [20]:
"""
train_df = pd.read_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train_full.csv")
val_df = pd.read_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val_full.csv")

train_df = train_df[["Imaging_XENum","crop_filepath","label"]]
val_df= val_df[["Imaging_XENum","crop_filepath","label"]]

train_df.to_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train.csv")
val_df.to_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val.csv")
"""

'\ntrain_df = pd.read_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train_full.csv")\nval_df = pd.read_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val_full.csv")\n\ntrain_df = train_df[["Imaging_XENum","crop_filepath","label"]]\nval_df= val_df[["Imaging_XENum","crop_filepath","label"]]\n\ntrain_df.to_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train.csv")\nval_df.to_csv("/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val.csv")\n'

In [21]:
train_df = pd.read_csv("/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train_detected_full2_exp4_caa.csv")
val_df = pd.read_csv("/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val_detected_full2.csv")

train_df = train_df[["Imaging_XENum","crop_filepath","apoe_label"]]
val_df= val_df[["Imaging_XENum","crop_filepath","apoe_label"]]

train_df.to_csv("/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train_detected.csv")
val_df.to_csv("/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val_detected.csv")

In [22]:
train_df["crop_filepath"].iloc[0]

'/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/test-patients/images/XE10-033_1_AmyB_1/x_32768_y_74752.png'

In [23]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(1024),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(1024),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(1024),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
}

In [24]:
def open_image(inputs):
    _, wsi_name, img_path, label = inputs
    img = Image.open(img_path)
    return wsi_name, img, int(label)

def apply_train_transforms(inputs):
    _, x, y = inputs
    return data_transforms["train"](x), y

def apply_val_transforms(inputs):
    wsi_name, x, y = inputs
    return wsi_name, data_transforms["val"](x), y

def build_data_pipe(csv_file, transform , batch_size=32):
    new_dp = dp.iter.FileOpener([csv_file])
    new_dp = new_dp.parse_csv(skip_lines=1)
    # returns tuples like ('0','filename', 'filepath', 'label')
    if transform == "train":
        new_dp = new_dp.shuffle()
    
    new_dp = new_dp.sharding_filter()
    # important to use sharding_filter after (not before) shuffling -For the data source that needs to be sharded, it is crucial to add Shuffler before ShardingFilter to ensure data are globally shuffled before being split into shards. Otherwise, each worker process would always process the same shard of data for all epochs. And, it means each batch would only consist of data from the same shard, which leads to low accuracy during training. However, it doesn’t apply to the data source that has already been sharded for each multi-/distributed process, since ShardingFilter is no longer required to be presented in the pipeline.

    new_dp = new_dp.map(open_image)

    if transform == "train":
        new_dp = new_dp.map(apply_train_transforms)
        new_dp = new_dp.batch(batch_size=batch_size, drop_last=True)

    elif transform == "val":
        new_dp = new_dp.map(apply_val_transforms)
        new_dp = new_dp.batch(batch_size=batch_size, drop_last=False)

    else:
        raise ValueError("Invalid transform argument.")

    new_dp = new_dp.map(torch.utils.data.default_collate)
    return new_dp

In [25]:
#batch_size = 1
batch_size = 16
#TRAIN_CSV = "/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train.csv"
#VAL_CSV = "/gladstone/finkbeiner/steve//work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val.csv"

TRAIN_CSV = "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/train_detected.csv"
VAL_CSV = "/mnt/new-nas/work/data/npsad_data/vivek/Datasets/amyb_wsi/apoe_training_dataset/val_detected.csv"


In [26]:
train_dp = build_data_pipe(TRAIN_CSV, "train", batch_size)
val_dp = build_data_pipe(VAL_CSV, "val", batch_size)

In [27]:
def dataset_size(csv_file):
    df = pd.read_csv(csv_file)
    return len(df)

train_datasize = dataset_size(TRAIN_CSV)
val_datasize = dataset_size(VAL_CSV)
dataset_sizes = {'train':train_datasize, 'val':val_datasize}

In [28]:
dataset_sizes

{'train': 1456, 'val': 4966}

In [29]:
train_loader = torch.utils.data.DataLoader(
    dataset=train_dp, shuffle=True, num_workers=4)

val_loader = torch.utils.data.DataLoader(
    dataset=val_dp, shuffle=False, num_workers=4)

dataloaders = {"train":train_loader, "val": val_loader}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Model Training

In [30]:
train_config = dict(
    epochs = 10,
    batch_size = 10,
    num_classes = 2,
    device_id = 0,
    eval_freq = 5,
)

In [31]:
model_config = dict(lr=0.001, momentum=0.9)

optim_config = dict(step_size=7, gamma=0.1)

In [32]:
def train_model( model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    log_metrics = list()

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs=inputs.squeeze()
                # print(inputs.shape)
                inputs=inputs.squeeze(0)
                labels = labels.squeeze()
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                #print(f'{phase} Loss: {loss.item():.4f} Batch No: {i}')
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            log_metrics.append(dict(epoch=epoch,phase=phase, loss=epoch_loss, metrics=epoch_acc))
            # deep copy the model
        
        #if (epoch+1)%train_config["eval_freq"]==0:
        #    test_model(model)
            
            #if phase == 'val' and epoch_acc > best_acc:
            ##if phase == 'val' and epoch+1==train_config["eval_freq"]:
                #test_model(model)
            #    best_acc = epoch_acc
            #    best_model_wts = copy.deepcopy(model.state_dict())
            
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 3600:.0f}h {time_elapsed % 60:.0f}m')
    #print(f'Best val Acc: {best_acc:4f}')
    #plot_training_curve(log_metrics)
    # load best model weights
    model.load_state_dict(best_model_wts)
    #run.log({"log":log_metrics})
    #torch.save({"model":model, "state": model.state_dict()}, '/gladstone/finkbeiner/steve/work/data/npsad_data/monika/LBD/WM_models/'+artifact_name+'.pth')
    #artifact = wandb.Artifact(artifact_name, type='files')
    #with artifact.new_file(f'ckpt/{epoch}.pt', 'wb') as f:
    #    torch.save(model.state_dict(), f)
    #run.log_artifact(artifact)
    #run.finish()
    return model, log_metrics

## RESNET-18 

In [33]:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, train_config["num_classes"])

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
#optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.0001)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)



In [34]:
model_ft, log_metrics = train_model( model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=train_config["epochs"])

Epoch 0/9
----------


train Loss: 0.3905 Acc: 0.7912
Epoch 1/9
----------
train Loss: 0.1900 Acc: 0.8949
Epoch 2/9
----------
train Loss: 0.1640 Acc: 0.9025
Epoch 3/9
----------
train Loss: 0.1473 Acc: 0.9073
Epoch 4/9
----------
train Loss: 0.0794 Acc: 0.9382
Epoch 5/9
----------
train Loss: 0.0781 Acc: 0.9402
Epoch 6/9
----------
train Loss: 0.0544 Acc: 0.9457
Epoch 7/9
----------
train Loss: 0.0519 Acc: 0.9505
Epoch 8/9
----------
train Loss: 0.0387 Acc: 0.9547
Epoch 9/9
----------
train Loss: 0.0464 Acc: 0.9512
Training complete in 0h 29m


In [35]:
torch.save({"model":model_ft, "state": model_ft.state_dict()}, '/mnt/new-nas/work/data/npsad_data/monika/Amy_plaque_Results/models/model_squished_label_4_detected.pth')

## Model Testing

In [36]:
path = '/mnt/new-nas/work/data/npsad_data/monika/Amy_plaque_Results/models/model_5ep_detected.pth'

In [37]:
def load_saved_model(path):
    checkpoint = torch.load(path)
    model_ft = checkpoint["model"]
    model_ft.load_state_dict(checkpoint['state'])
    return model_ft

In [38]:
model_ft = load_saved_model(path)

In [39]:
def test_model(model):
    was_training = model.training
    model.eval()
    actual_labels = []
    pred_labels = []
    wsi_names = []
    scores=[]
    with torch.no_grad():
        for i, (wsi, inputs, labels) in enumerate(dataloaders['val']):
            #print(inputs.shape)
            inputs=inputs.squeeze()
            labels = labels.squeeze()
            inputs = inputs.to(device)
            labels = labels.to(device)
            #print(labels)
            actual_labels.extend(labels.tolist())
            outputs = model(inputs)
            scores.extend(outputs.cpu().tolist())
            _, preds = torch.max(outputs, 1)
            pred_labels.extend(preds.tolist())
            wsi=[x[0] for x in wsi]
            wsi_names.extend(wsi)
            output_df = pd.DataFrame({"wsi_name":wsi_names,"actual_labels":actual_labels,"pred_labels":pred_labels,"scores":scores})
            if i%500==0:
                print(i, "Done")
        classes = [0,1]
        eval_metrics = pd.DataFrame(columns=["WSI","accu_score","precision-control","recall-control","f1-score-control","support-control", "precision-APOE","recall-APOE","f1-score-APOE","support-APOE"])
        # for wsi in output_df['wsi_name'].unique():
        #     tmp = output_df[output_df["wsi_name"]==wsi]  
        #     acc_score = accuracy_score(tmp["actual_labels"], tmp["pred_labels"])
        #     prec_rec = precision_recall_fscore_support(tmp["actual_labels"], tmp["pred_labels"], labels=[0,1])
        #     # eval_metrics.loc[ind] = (wsi, acc_score, prec_rec[0][0],prec_rec[1][0],prec_rec[2][0], prec_rec[0][1],prec_rec[1][1],prec_rec[2][1], prec_rec[0][2],prec_rec[1][2],prec_rec[2][2])
        #     #plot_roc_auc_curve(run, tmp["actual_labels"], tmp["scores"], classes, wsi)
        #     #plot_pr_curve(run, tmp["actual_labels"], tmp["scores"], classes, wsi)
        #     ind=ind+1
        #     #wandb.summary["Evaluation Metric for WSI :" + wsi]=acc_score
        
        #tbl = wandb.Table(data=eval_metrics)
        #run.log({"Test Evaluation Metric": tbl})
        #eval_metrics.to_csv(eval_path + "eval_metric.csv")
        # print(eval_metrics)
    return output_df

In [40]:
output_df = test_model(model_ft)

0 Done


RuntimeError: Caught RuntimeError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 144, in __next__
    return self._get_next()
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 132, in _get_next
    result = next(self.iterator)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 215, in wrap_next
    result = next_func(*args, **kwargs)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/datapipe.py", line 369, in __next__
    return next(self._datapipe_iter)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 185, in wrap_generator
    response = gen.send(request)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/iter/callable.py", line 123, in __iter__
    yield self._apply_fn(data)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/datapipes/iter/callable.py", line 88, in _apply_fn
    return self.fn(data)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 142, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 142, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/home/vivek/.virtualenvs/mask_rcnn-w0kK5vJa/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 162, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 1024, 1024] at entry 0 and [3, 4161, 1024] at entry 10
This exception is thrown by __iter__ of MapperIterDataPipe()


In [None]:
output_df

Unnamed: 0,wsi_name,actual_labels,pred_labels,scores,correct
0,XE16-033,1,0,"[0.7342130541801453, -0.30953988432884216, -0....",False
1,XE16-033,1,0,"[0.39095649123191833, -0.3952394425868988, -0....",False
2,XE16-033,1,0,"[0.7152571678161621, -0.23949187994003296, -0....",False
3,XE16-033,1,0,"[0.9912577867507935, -0.4821361303329468, -0.2...",False
4,XE16-033,1,0,"[0.8842447400093079, -0.4726897180080414, -0.3...",False
...,...,...,...,...,...
1194,XE14-004,0,0,"[0.555978536605835, 0.005207183305174112, -0.3...",True
1195,XE14-004,0,0,"[0.5898568034172058, -0.3283420503139496, -0.4...",True
1196,XE14-004,0,0,"[0.6234608292579651, -0.195412740111351, -0.27...",True
1197,XE14-004,0,0,"[0.454772412776947, -0.09624777734279633, -0.2...",True


In [None]:
(output_df["actual_labels"]==output_df["pred_labels"]).sum()/len(output_df)

0.11426188490408674

In [None]:
output_df["correct"] = output_df["actual_labels"]==output_df["pred_labels"]

In [None]:
output_df.groupby(["wsi_name"])["correct"].sum()/output_df.groupby(["wsi_name"])["correct"].count()

wsi_name
XE08-017    1.0
XE11-027    0.0
XE14-004    1.0
XE16-033    0.0
Name: correct, dtype: float64

In [None]:
for wsi in output_df["wsi_name"].unique():
    print("-----Confusion Matrix for ", wsi, "-------")
    actual = output_df[output_df["wsi_name"]==wsi]["actual_labels"]
    pred = output_df[output_df["wsi_name"]==wsi]["pred_labels"]
    print(confusion_matrix(actual, pred))

-----Confusion Matrix for  XE16-033 -------
[[  0   0]
 [428   0]]
-----Confusion Matrix for  XE11-027 -------
[[  0   0]
 [634   0]]
-----Confusion Matrix for  XE08-017 -------
[[10]]
-----Confusion Matrix for  XE14-004 -------
[[127]]
