# Train the CNNs as seen in "Predicting the HER2 status in esophageal cancer from tissue microarrays using convolutional neural networks".

### Dependencies

In [3]:
%git clone https://github.com/bozeklab/HER2-overexpression.git
%cd ./HER2-overexpression

import os
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.nn import CrossEntropyLoss
from torchvision.models import resnet34
from torchvision import transforms
import pandas as pd

from src.utils import calculate_weights
from src.dataset import DistributedWeightedSampler, ImageDataset
from src.models import ResnetABMIL
from train import train_step, validate_step

### Hyperparameters (see the repository's README for more info)

In [4]:
# --- Task and architecture -------------------------------------------------
# 'ihc-score' trains a 4-class model, 'her2-status' a 2-class model.
task = "her2-status"
# Architecture name: 'resnet34' or 'abmil'.
model = "resnet34"

# --- Data ------------------------------------------------------------------
# CSV loaded with pandas; the training cell reads a 'filename' column and a
# label column named after `task`.
train_csv = "./train.csv"
# Optional validation CSV. With None, validation is skipped and the learning
# rate scheduler is never stepped.
val_csv = None
# Images are resized to (img_size, img_size) before being turned into tensors.
img_size = 1024
# Column of the training CSV used to compute the sampling weights;
# None falls back to `task`.
weighted_sampler_label = None
# Passed straight to the DataLoaders.
batch_size = 1
num_workers = 0

# --- Optimisation ----------------------------------------------------------
# Initial learning rate for Adam.
learning_rate = 1e-3
# Number of epochs to run.
epochs = 100
# `factor` and `patience` arguments of ReduceLROnPlateau.
scheduler_factor = 0.1
scheduler_patience = 10

# --- Output ----------------------------------------------------------------
# Directory that receives one checkpoint file per epoch (created if missing).
checkpoints_dir = "./checkpoints"

### Train!

In [None]:
# Make sure the checkpoint directory exists before training starts.
os.makedirs(checkpoints_dir, exist_ok=True)
# Fixed seed for torch's random number generator.
torch.manual_seed(0)

# The task determines the size of the classification head.
if task == 'ihc-score':
    num_classes = 4
elif task == 'her2-status':
    num_classes = 2
else:
    # The message now names the value that is actually accepted ('her2-status').
    raise ValueError('Task should be ihc-score or her2-status')

# NOTE: `model` is rebound here from the architecture name (a string) to the
# network itself. Re-running this cell without first re-running the
# hyperparameter cell therefore ends in the ValueError below.
if model == 'resnet34':
    model = resnet34(pretrained=False, num_classes=num_classes).cuda()
elif model == 'abmil':
    model = ResnetABMIL(num_classes=num_classes).cuda()
else:
    raise ValueError('Model should be resnet34 or abmil')

# Training-time preprocessing: resize to a square image, apply random
# horizontal/vertical flips as augmentation, then convert to a tensor.
train_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

train_df = pd.read_csv(train_csv)
train_dataset = ImageDataset(train_df, fn_col='filename', lbl_col=task, transform=train_transform)

# Sample by the task label unless another column was requested.
if weighted_sampler_label is None:
    weighted_sampler_label = task
# NOTE(review): calculate_weights presumably returns one weight per sample so
# that the classes are rebalanced - confirm in src.utils.
weights = calculate_weights(torch.tensor(train_df[weighted_sampler_label].values))
# One epoch draws len(weights) samples according to those weights.
train_sampler = WeightedRandomSampler(weights, len(weights))
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    sampler=train_sampler,
)

# Validation is optional; `val_loader` only exists when a validation CSV is given.
if val_csv is not None:
    # Same preprocessing as training, minus the random flips.
    val_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])
    val_df = pd.read_csv(val_csv)
    val_dataset = ImageDataset(val_df, fn_col='filename', lbl_col=task, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

# Adam over all model parameters, with a very small weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-8,
)

# Reduce the learning rate when the monitored value stops improving.
# The mode is 'max' because the training loop steps the scheduler with the
# validation accuracy.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=scheduler_factor,
    patience=scheduler_patience,
    min_lr=1e-15,
)

# Classification loss handed to the train and validation steps.
criterion = CrossEntropyLoss()

# `epoch0` is the index of the first epoch; training runs for `epochs`
# epochs starting from it.
epoch0 = 0
epoch = epoch0
while epoch < epoch0 + epochs:

    train_phase_results = train_step(train_loader, model, criterion, optimizer)

    if val_csv is not None:
        val_phase_results = validate_step(val_loader, model, criterion)
        acc = val_phase_results['Accuracy']
        # The scheduler monitors validation accuracy, so it only advances
        # when validation is enabled.
        scheduler.step(acc)
    else:
        # Placeholder so the printing and checkpointing below work without validation.
        val_phase_results = {'Loss': '', 'Accuracy': ''}

    print('Epoch {} finished.'.format(epoch))
    print('Train phase: ', train_phase_results)
    print('Val phase: ', val_phase_results)
    print('\n')

    # One checkpoint per epoch. The scheduler state is stored alongside the
    # model and optimizer so a resumed run can restore the learning-rate
    # schedule; 'accuracy' is '' when no validation set is used.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
        'accuracy': val_phase_results['Accuracy'],
    }, os.path.join(checkpoints_dir, 'checkpoint_{}.pth.tar'.format(epoch)))
    epoch += 1