In [23]:
import sys
sys.path.append('/home/colin/covid-blood/')
from config import get_config
from torchvision import transforms
from utils import setup_torch, get_covid_transforms, load_model
import wandb
from dataloader import load_all_patients, load_pbc_data
from models.imagenet import get_model
from models.multi_instance import AttentionModel, GatedAttentionModel, SimpleMIL
from mil_trainer import ClassificationTrainer
from torch import optim
import warnings
from tqdm import tqdm
import json
from torchvision.datasets import ImageFolder
import os
import torch
import matplotlib.pyplot as plt

In [22]:
def create_mega_fold(run_ids, root_dir='/home/colin/results_cov/', control=False):
    mega_fold = {}
    if control:
        root_dir = os.path.join(root_dir, 'control')
    for run_id in run_ids:
        files = glob.glob(os.path.join(root_dir, f'*{run_id}*.json'))
        assert len(files) == 1
        file = files[0]
        with open(file) as fp:
            all_data = json.load(fp)        
            test_fold = {patient: data for patient, data in all_data.items() if 'predictions' in data}
            for patient, data in test_fold.items():
                if patient in mega_fold:
                    mega_fold[patient]['predictions'] += data['predictions']
                else:
                    mega_fold[patient] = data
    return mega_fold

def denormalize_image(image, mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]):
    image = image.clone()
    image = image * torch.tensor(std).view(3, 1, 1)
    image = image + torch.tensor(mean).view(3, 1, 1)
    return image.permute(1, 2, 0)

In [10]:
setup_torch(0, 1, 1)

In [2]:
data_transforms = get_covid_transforms(image_size=224, center_crop_amount=224)

0.5 blank


In [44]:
batch_size=64
data_dir = '/home/colin/filtered/'
data_transforms = {
    'train': data_transforms['train'],
    'val': data_transforms['val']
}
# luckily torchvision has a nice class for this scenario
# Create training and validation datasets
image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
train_loader, val_loader = [torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for
                            x in ['train', 'val']]

In [45]:
classes = ['basophil',
 'eosinophil',
 'erythroblast',
 'garbage',
 'ig',
 'lymphocyte',
 'monocyte',
 'neutrophil',
 'platelet']

In [46]:
wbc_model = get_model('resnet50', len(classes), True).cuda()

In [47]:
optimizer = optim.AdamW(wbc_model.parameters(), lr=3e-4)

In [48]:
wbc_model.train()
pass

In [49]:
loss = torch.nn.CrossEntropyLoss()

In [50]:
epochs = 10
for e in range(epochs):
    total_loss = 0
    total_acc = 0
    print("STARTING EPOCH: ", e)
    wbc_model.train()
    for images, labels in tqdm(train_loader):
        images, labels = images.cuda(), labels.cuda()
        output = wbc_model(images)
        _, preds = torch.max(output, 1)
        acc = torch.sum(preds == labels.data).float() / len(labels)
        loss_val = loss(output, labels)
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
        total_loss += float(loss_val)
        total_acc += float(acc)
    print("TRAIN ACC: ", total_acc/len(train_loader))
    print("TRAIN LOSS: ", total_loss/len(train_loader))
    wbc_model.eval()
    total_loss = 0
    total_acc = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images, labels = images.cuda(), labels.cuda()
            output = wbc_model(images)
            _, preds = torch.max(output, 1)
            acc = torch.sum(preds == labels.data).float() / len(labels)
            loss_val = loss(output, labels)
            total_loss += float(loss_val)
            total_acc += float(acc)
    print("VAL ACC: ", total_acc/len(val_loader))
    print("VAL LOSS: ", total_loss/len(val_loader))

  0%|          | 0/249 [00:00<?, ?it/s]

STARTING EPOCH:  0


100%|██████████| 249/249 [04:58<00:00,  1.20s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.7462514531181519
TRAIN LOSS:  1.0384605268397964


100%|██████████| 45/45 [00:12<00:00,  3.68it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.8213998542891608
VAL LOSS:  0.7446866101688809
STARTING EPOCH:  1


100%|██████████| 249/249 [04:46<00:00,  1.15s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.8800597126225391
TRAIN LOSS:  0.5055736068740906


100%|██████████| 45/45 [00:11<00:00,  4.01it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.8602887431780497
VAL LOSS:  0.5310535828272501
STARTING EPOCH:  2


100%|██████████| 249/249 [04:48<00:00,  1.16s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.8955592106145069
TRAIN LOSS:  0.39304940809447125


100%|██████████| 45/45 [00:11<00:00,  4.05it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.8837719294759963
VAL LOSS:  0.4411036749680837
STARTING EPOCH:  3


100%|██████████| 249/249 [04:49<00:00,  1.16s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.904426918211711
TRAIN LOSS:  0.3392461881819499


100%|██████████| 45/45 [00:11<00:00,  4.02it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.8955774850315518
VAL LOSS:  0.3759029832151201
STARTING EPOCH:  4


100%|██████████| 249/249 [04:49<00:00,  1.16s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.9083538100901378
TRAIN LOSS:  0.3126192086312666


100%|██████████| 45/45 [00:10<00:00,  4.09it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.8893274850315518
VAL LOSS:  0.3726087427801556
STARTING EPOCH:  5


100%|██████████| 249/249 [04:49<00:00,  1.16s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.9164751109827953
TRAIN LOSS:  0.2866353825273284


100%|██████████| 45/45 [00:11<00:00,  4.02it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.9139802628093295
VAL LOSS:  0.3160105721818076
STARTING EPOCH:  6


100%|██████████| 249/249 [04:48<00:00,  1.16s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.9174791270470524
TRAIN LOSS:  0.2702555116041597


100%|██████████| 45/45 [00:11<00:00,  4.04it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.9118969294759962
VAL LOSS:  0.31129692097504935
STARTING EPOCH:  7


100%|██████████| 249/249 [04:53<00:00,  1.18s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.9226676443494467
TRAIN LOSS:  0.25664085329297076


100%|██████████| 45/45 [00:11<00:00,  3.99it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.9073830405871074
VAL LOSS:  0.30459754632578956
STARTING EPOCH:  8


100%|██████████| 249/249 [04:56<00:00,  1.19s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.9217230766173826
TRAIN LOSS:  0.25345513855956164


100%|██████████| 45/45 [00:11<00:00,  3.97it/s]
  0%|          | 0/249 [00:00<?, ?it/s]

VAL ACC:  0.9170138888888889
VAL LOSS:  0.2719445832901531
STARTING EPOCH:  9


100%|██████████| 249/249 [04:56<00:00,  1.19s/it]
  0%|          | 0/45 [00:00<?, ?it/s]

TRAIN ACC:  0.927813226678764
TRAIN LOSS:  0.23585226470567136


100%|██████████| 45/45 [00:11<00:00,  4.00it/s]

VAL ACC:  0.9203581876224942
VAL LOSS:  0.2638611737224791





In [51]:
model_path = '/home/colin/wbc_model_2021.pth'
torch.save(wbc_model.state_dict(), model_path)