In [None]:
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pytorch_lightning as pl
import wandb
from timm import create_model
from einops import rearrange, reduce
from sklearn.metrics import accuracy_score, f1_score
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torchvision.transforms import ColorJitter, Normalize, Compose, RandomAffine

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from multiplex_imaging_pipeline.region_classification import RegionImgTransform, RegionDataset, RegionClassifier, ModelLightning


In [None]:
# Report whether PyTorch can see a CUDA device; the boolean is shown as the cell output.
torch.cuda.is_available()

In [None]:
# Load the serialized train / validation / prediction datasets in one pass.
# NOTE: torch.load unpickles Python objects, so only point it at trusted files.
train_ds, val_ds, pred_ds = (
    torch.load(dataset_fp)
    for dataset_fp in (
        '/data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v4/datasets/train.pt',
        '/data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v4/datasets/val.pt',
        '/data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v4/datasets/prediction.pt',
    )
)

In [None]:
# Wrap each dataset in a DataLoader; only the training loader shuffles.
batch_size = 32
shared_loader_kwargs = {'batch_size': batch_size, 'num_workers': 10}

train_dl = DataLoader(train_ds, shuffle=True, **shared_loader_kwargs)
val_dl = DataLoader(val_ds, **shared_loader_kwargs)
pred_dl = DataLoader(pred_ds, **shared_loader_kwargs)

In [None]:
# Grab the first training sample and list the fields it exposes.
d = train_ds[0]
d.keys()

In [None]:
# Inspect the dtypes of the image and mask and the shape of the target for that sample.
d['rgb'].dtype, d['mask'].dtype, d['y'].shape

In [None]:
# Preview the first training sample as a channels-last image scaled to [0, 1].
# rearrange only permutes axes, so `img` shares memory with d['rgb']; the
# normalisation is therefore done out-of-place to leave the sample untouched
# (this also avoids in-place true division failing on integer inputs).
img = rearrange(d['rgb'], 'c h w -> h w c')
img = img - img.min()
img = img / img.max()
plt.imshow(img)

In [None]:
# Display the first entry along the leading axis of the sample's mask.
plt.imshow(d['mask'][0])

In [None]:
# Show the target stored for the sample.
d['y']

In [None]:
# Pull a single batch from the training loader and list its fields.
b = next(iter(train_dl))
b.keys()

In [None]:
# Confirm the dtypes of the batched image and mask tensors.
b['rgb'].dtype, b['mask'].dtype

In [None]:
# Smoke test: size a classifier from the batch (second axis of the targets and
# of the images) and run one forward pass to inspect the output shape.
n_classes = b['y'].shape[1]
n_channels = b['rgb'].shape[1]

model = RegionClassifier(n_classes, n_channels=n_channels)
out = model(b['rgb'], b['mask'])
out.shape

In [None]:
# Smoke test: compute the model's loss for the forward-pass output against the batch targets.
model.calculate_loss(out, b['y'])

In [None]:
class LoggingCallback(pl.Callback):
    """Log a few input images, captioned with the predicted class, via the trainer's logger.

    Images are logged only for the first batch (``batch_idx == 0``) of every
    ``log_every``-th epoch, under the key "train/rgb" for training batches and
    "val/rgb" for validation batches.

    Args:
        log_every: Log on epochs whose index is a multiple of this value.
        log_n_samples: Maximum number of samples taken from the batch.
        labels: Optional sequence mapping class index -> caption. When omitted,
            the notebook-level ``train_ds.labels`` is looked up at logging time.

    Notes:
        Assumes ``batch['rgb']`` is a float tensor laid out as
        (batch, channels, h, w) and that the step outputs are a dict whose
        'probs' entry is (batch, n_classes) -- TODO confirm against ModelLightning.
        The logger attached to the trainer must provide ``log_image``.
    """

    def __init__(self, log_every=10, log_n_samples=8, labels=None):
        super().__init__()
        self.log_every = log_every
        self.log_n_samples = log_n_samples
        self.labels = labels

    def _log_samples(self, key, trainer, outputs, batch, batch_idx):
        """Log up to ``log_n_samples`` images of ``batch`` under ``key`` when a log is due."""
        if batch_idx != 0 or trainer.current_epoch % self.log_every != 0:
            return

        # Work on a detached CPU copy so the in-place scaling below cannot
        # touch the tensors used by the training step.
        img = batch['rgb'][:self.log_n_samples].clone().detach().cpu()
        img -= img.min()
        peak = img.max()
        if peak > 0:
            # Skip the division for an all-constant batch to avoid NaNs.
            img /= peak

        predicted = outputs['probs'][:self.log_n_samples].argmax(dim=-1).detach().cpu().numpy()
        labels = train_ds.labels if self.labels is None else self.labels

        # Use the logger bound to the trainer instead of a notebook global.
        trainer.logger.log_image(
            key=key,
            # Samples whose leading axis is not of size 1 or 3 are shown via their first slice only.
            images=[sample if sample.shape[0] in (1, 3) else sample[0] for sample in img],
            caption=[labels[idx] for idx in predicted]
        )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log training samples for the first batch of each logging epoch."""
        self._log_samples("train/rgb", trainer, outputs, batch, batch_idx)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Log validation samples for the first batch of each logging epoch."""
        self._log_samples("val/rgb", trainer, outputs, batch, batch_idx)


## training model

In [None]:
# Experiment-tracking project name and the directory passed to the logger as its save location.
project = 'region_classifier'
log_dir = '/data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v4/logs'
# Create the log directory (and any missing parents); a no-op if it already exists.
Path(log_dir).mkdir(parents=True, exist_ok=True)

In [None]:
# Weights & Biases logger for this run; it is later handed to the Trainer and
# referenced by LoggingCallback.
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project=project, save_dir=log_dir)

In [None]:
# wandb.finish()

In [None]:
# Settings that drive the Trainer and the logging callback.
training_config = {
    'train_samples': len(train_ds),
    'val_samples': len(val_ds),
    'log_n_samples': 16,
    'max_epochs': 500,
    'check_val_every_n_epoch': 1,
    'log_every': 1,
    'accelerator': 'gpu',
    'devices': [1],
    'lr': 2e-4,
    'batch_size': batch_size,
    'precision': 32
}

# Model-level settings, with the training settings nested under 'training'.
config = {
    'n_classes': train_ds.y.shape[1],
    'n_features': 100,
    'backbone': 'resnet34',
    'n_channels': train_ds[0]['rgb'].shape[0],
    'training': training_config,
}

# Record the full configuration on the logger's experiment.
logger.experiment.config.update(config)

In [None]:
# Build the classifier from the config and wrap it in the Lightning module.
classifier_kwargs = {name: config[name] for name in ('n_channels', 'backbone')}
m = RegionClassifier(config['n_classes'], **classifier_kwargs)

learning_rate = config['training']['lr']
model = ModelLightning(m, lr=learning_rate)

In [None]:
# !mkdir /data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v3/ckpts

In [None]:
# Assemble the Lightning trainer from the `training` section of the config.
trainer = pl.Trainer(
    callbacks=[
        # Periodically logs sample images with their predicted labels.
        LoggingCallback(
            log_every=config['training']['log_every'],
            log_n_samples=config['training']['log_n_samples']
        ),
        # Keeps the five checkpoints ranked best on the monitored "val/loss" metric.
        # Checkpoints go under classifier_v4, alongside this run's datasets,
        # logs and results (the path previously pointed at classifier_v3).
        ModelCheckpoint(
            dirpath="/data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v4/ckpts",
            save_top_k=5, monitor="val/loss"
        )
    ],
    devices=config['training']['devices'],
    accelerator=config['training']['accelerator'],
    check_val_every_n_epoch=config['training']['check_val_every_n_epoch'],
    enable_checkpointing=True,
    max_epochs=config['training']['max_epochs'],
    precision=config['training']['precision'],
    logger=logger
)

In [None]:
# Train the Lightning module on the training loader, validating on the validation loader.
trainer.fit(model=model, train_dataloaders=train_dl, val_dataloaders=val_dl)

In [None]:
# The trainer exposes its first checkpoint callback directly, so there is no
# need to match callbacks by the string form of their type.
cb = trainer.checkpoint_callback
cb.best_model_path

In [None]:
# Restore the Lightning module from the best checkpoint recorded by the checkpoint callback.
# NOTE(review): ModelLightning is constructed above with a RegionClassifier instance;
# presumably load_from_checkpoint can rebuild it from the checkpoint alone — confirm
# that its constructor arguments are saved, otherwise they must be passed here.
best = ModelLightning.load_from_checkpoint(cb.best_model_path)

In [None]:
# Run prediction with the restored module over the prediction loader; the
# result is consumed below as a sequence of per-batch tensors.
result = trainer.predict(best, dataloaders=pred_dl)

In [None]:
# Stack the per-batch prediction tensors along the batch axis in a single call
# (concatenating inside a loop re-copies the accumulated tensor every step).
probs = torch.cat(result, dim=0)
probs.shape

In [None]:
# Tabulate the predictions: one row per prediction key, one column per class label.
prob_values = probs.detach().cpu().numpy()
df = pd.DataFrame(
    data=prob_values,
    columns=train_ds.labels,
    index=pred_ds.keys,
).rename_axis(index='region_id')
df

In [None]:
# Write the probability table as tab-separated text. The results directory is
# created first so the write cannot fail on a missing folder after training.
results_fp = Path('/data/estorrs/multiplex_data/analysis/dcis_region_analysis/classifier_v4/results/probs.txt')
results_fp.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(results_fp, sep='\t')