In [None]:
import os
import re
import uuid
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import tifffile
from PIL import Image, ImageOps
from einops import rearrange
import matplotlib.pyplot as plt

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
from violet.utils.dataloaders import listfiles

In [None]:
fps = sorted(listfiles('/data/violet/sandbox/tcia_pda_run1/st/normalized/'))
len(fps)

In [None]:
# only 1x and 4x
fps = [fp for fp in fps if '1.jpeg' in fp or '4.jpeg' in fp]
len(fps)

In [None]:
import timm 
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

###### data loaders

In [None]:
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.data.auto_augment import RandAugment, rand_augment_ops

# Throwaway backbone (num_classes=0 -> no classifier head), created only to look up
# timm's default preprocessing config (input size, mean/std, ...). It stays on CPU.
model = timm.create_model('efficientnetv2_s', num_classes=0)
config = resolve_data_config({}, model=model)
config

In [None]:
# Pre-split ST tile directories. Each must contain `<sample>_<res>.jpeg` tiles and a
# tab-separated `targets.txt` (both read by MultiresDataset).
train_dir = '/data/violet/sandbox/tcia_pda_run1/st/train_v2'
val_dir = '/data/violet/sandbox/tcia_pda_run1/st/val_v2'

In [None]:
# timm RandAugment op names; get_training_transform samples num_layers=2 of them per call.
# Shear/translate ops are disabled below, so 'Rotate' is the only geometric op kept.
# NOTE(review): 'Invert' and 'Solarize'/'SolarizeAdd' strongly alter colours — confirm
# they are appropriate for these tiles.
_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'SolarizeAdd',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
#     'ShearX',
#     'ShearY',
#     'TranslateXRel',
#     'TranslateYRel',
    #'Cutout'  # NOTE(review): earlier note says Cutout is replaced by random erasing applied separately, but get_training_transform has no random-erasing step — confirm.
]

In [None]:
def get_training_transform(resize=(288, 288)):
    """Build the stochastic preprocessing pipeline used for training images.

    Pipeline: resize -> RandAugment (2 ops drawn from ``_RAND_TRANSFORMS``) ->
    tensor conversion -> per-channel normalization.

    Args:
        resize: target size handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized tensor.
    """
    # Presumably per-channel mean / std of the tile dataset — constants kept verbatim.
    channel_mean = (0.76806694, 0.47375619, 0.58864233)
    channel_std = (0.17746654, 0.21851493, 0.18837758)
    augment_ops = rand_augment_ops(transforms=_RAND_TRANSFORMS)
    pipeline = (
        transforms.Resize(resize),
        RandAugment(augment_ops, num_layers=2),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    )
    return transforms.Compose(pipeline)

def get_val_transform(resize=(288, 288)):
    """Build the deterministic preprocessing pipeline for validation / inference images.

    Pipeline: resize -> tensor conversion -> per-channel normalization (same constants
    as the training pipeline, no random augmentation).

    Args:
        resize: target size handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized tensor.
    """
    channel_mean = (0.76806694, 0.47375619, 0.58864233)
    channel_std = (0.17746654, 0.21851493, 0.18837758)
    pipeline = (
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    )
    return transforms.Compose(pipeline)

def get_filepath_map(root_dir, resolution=('1', '4'), regex=r'.jpeg$'):
    """Group the image files under ``root_dir`` by sample and resolution.

    File names are parsed as ``<sample>_<res>.<ext>`` (e.g. ``abc_1.jpeg``). A file is
    kept only if its ``<res>`` equals one of ``resolution``. Only samples that have a
    file for every requested resolution are returned.

    Args:
        root_dir: directory passed to ``listfiles``.
        resolution: resolution suffixes to keep (compared as strings).
        regex: file-name pattern forwarded to ``listfiles``.

    Returns:
        dict mapping sample id -> {res: filepath}.
    """
    wanted = {str(r) for r in resolution}
    # Parse the file *stem* so dots in directory or sample names are harmless. Compare
    # the full numeric suffix so that e.g. `_11` is not mistaken for `_1`.
    name_pattern = re.compile(r'^(.*)_([0-9]+)$')
    by_sample = {}
    for fp in listfiles(root_dir, regex=regex):
        match = name_pattern.match(Path(fp).stem)
        if match is None:
            continue
        sample, res = match.groups()
        if res in wanted:
            by_sample.setdefault(sample, {})[res] = fp
    return {k: v for k, v in by_sample.items() if len(v) == len(wanted)}
        

class MultiresDataset(Dataset):
    """Multi-resolution image-regression dataset.

    Each item is one sample with one image per requested resolution (keys like ``'1x'``,
    ``'4x'``) plus a row of regression targets from ``<root_dir>/targets.txt``. The
    targets are scaled per column into [0, 1] by dividing by a per-column maximum.

    Args:
        root_dir: directory scanned for images; must also contain ``targets.txt``
            (tab-separated, first column = sample id).
        transform: augmenting pipeline; defaults to ``get_training_transform()``.
        val_transform: deterministic pipeline; defaults to ``get_val_transform()``.
        resolution: resolution suffixes that every sample must have.
        img_regex: file-name regex forwarded to ``get_filepath_map``.
        transform_prob: per-image probability of using ``transform`` rather than
            ``val_transform``.
        target_max: optional per-column scale (Series indexed by target column, or
            array). Pass ``train_ds.target_max`` when building a validation/inference
            dataset so that both splits share one scale. By default the max of this
            directory's ``targets.txt`` is used.
    """

    def __init__(self, root_dir, transform=None, val_transform=None,
                 resolution=('1', '4'), img_regex=r'.jpeg$', transform_prob=.8,
                 target_max=None):
        self.root_dir = root_dir
        self.resolution = resolution
        self.transform = get_training_transform() if transform is None else transform
        self.transform_prob = transform_prob
        self.val_transform = get_val_transform() if val_transform is None else val_transform

        self.filepath_map = get_filepath_map(self.root_dir, resolution=resolution, regex=img_regex)
        self.samples = list(self.filepath_map.keys())
        target_df = pd.read_csv(os.path.join(root_dir, 'targets.txt'), sep='\t', index_col=0)
        target_df, self.target_max = self._normalize_targets(target_df, target_max)
        # Row order follows self.samples, so target_df.values[idx] matches samples[idx].
        self.target_df = target_df.loc[self.samples]
        self.target_labels = self.target_df.columns

    @staticmethod
    def _normalize_targets(target_df, target_max=None):
        """Divide each target column by its scale.

        Returns ``(scaled_df, scale)``. The scale is ``target_max`` if given, otherwise
        the column max of ``target_df``. A zero scale is replaced by 1, so an all-zero
        column stays 0 instead of becoming NaN (0 / 0).
        """
        if target_max is None:
            target_max = np.max(target_df.values, axis=0)
        scale = pd.Series(target_max, index=target_df.columns, dtype=float)
        scale = scale.where(scale != 0, 1.0)
        return target_df / scale, scale

    def __len__(self):
        return len(self.filepath_map)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        images = {}
        for res, fp in self.filepath_map[sample].items():
            # Read eagerly and close the handle. Force RGB so transforms always get 3 channels.
            with Image.open(fp) as img:
                img = img.convert('RGB')
            # torch's RNG is re-seeded per DataLoader worker (numpy's is not), so the
            # augment/no-augment choice stays independent across workers.
            use_augmentation = torch.rand(()).item() < self.transform_prob
            images[f'{res}x'] = self.transform(img) if use_augmentation else self.val_transform(img)
        return {
            'sample': sample,
            'images': images,
            'targets': self.target_df.values[idx]
        }
    
class MultiresPredictionDataset(Dataset):
    """Unlabelled multi-resolution dataset for inference.

    Uses the same file discovery as the training dataset (one image per resolution per
    sample), but reads no ``targets.txt``, so items carry no targets.

    Args:
        root_dir: directory scanned for images.
        transform: pipeline applied to every image. Defaults to the deterministic
            ``get_val_transform()``, because inference inputs should not be augmented.
        resolution: resolution suffixes that every sample must have.
        img_regex: file-name regex forwarded to ``get_filepath_map``.
    """

    def __init__(self, root_dir, transform=None, resolution=('1', '4'), img_regex=r'.jpeg$'):
        self.root_dir = root_dir
        self.resolution = resolution
        self.transform = get_val_transform() if transform is None else transform

        self.filepath_map = get_filepath_map(self.root_dir, resolution=resolution, regex=img_regex)
        self.samples = list(self.filepath_map.keys())

    def __len__(self):
        return len(self.filepath_map)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        images = {}
        for res, fp in self.filepath_map[sample].items():
            # Close the file handle promptly. Force RGB so transforms always get 3 channels.
            with Image.open(fp) as img:
                images[f'{res}x'] = self.transform(img.convert('RGB'))
        return {
            'sample': sample,
            'images': images,
        }

In [None]:
# The training split uses RandAugment with probability `transform_prob` (default .8).
# The validation split passes the deterministic pipeline as `transform`, and its
# `val_transform` also defaults to a deterministic pipeline, so val images are never augmented.
# NOTE(review): each split divides its targets by the max of its own targets.txt, so
# train and val targets may be on different scales — confirm this is intended.
train_transform, val_transform = get_training_transform(), get_val_transform()
train_ds = MultiresDataset(train_dir, transform=train_transform)
val_ds = MultiresDataset(val_dir, transform=val_transform)

In [None]:
# Number of samples that have every requested resolution, per split.
len(train_ds), len(val_ds)

In [None]:
# ls = [len(d['images']) for d in train_ds]
# from collections import Counter
# Counter(ls).most_common()


In [None]:
# One item: dict with 'sample', 'images' ({'1x': ..., '4x': ...}) and 'targets'.
train_ds[0]

In [None]:
# The training loader reshuffles each epoch. The validation order is fixed.
batch_size = 16
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
)
val_dl = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Default collation keeps the dict structure; each b['images'][res] is a stacked
# (batch, C, H, W) tensor.
b = next(iter(train_dl))
b.keys()

In [None]:
b['images']['1x'].shape

In [None]:
from skimage.exposure import rescale_intensity
def display_tensor(x):
    """Show one image tensor of shape (C, H, W) with matplotlib.

    The tensor is detached and moved to CPU first, so GPU tensors and tensors that
    require grad can be shown directly. Intensities are min-max rescaled to [0, 1] for
    display only. This does not invert the dataset normalization.
    """
    img = x.detach().cpu().permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
    img = rescale_intensity(img, out_range=(0., 1.))
    plt.imshow(img)
    plt.show()

In [None]:
# Show the first few (1x, 4x) pairs of the training batch. The count is capped at the
# batch size so a small dataset or final batch does not raise IndexError.
n_show = min(10, len(b['images']['1x']))
for i in range(n_show):
    display_tensor(b['images']['1x'][i])
    display_tensor(b['images']['4x'][i])

In [None]:
# The same view for one validation batch (val_ds was built with the deterministic val_transform).
b = next(iter(val_dl))
n_show = min(10, len(b['images']['1x']))
for i in range(n_show):
    display_tensor(b['images']['1x'][i])
    display_tensor(b['images']['4x'][i])

In [None]:
class MultiresRegressor(torch.nn.Module):
    """Two-stream regressor over a local and a global view of the same sample.

    Each view goes through its own timm backbone (``num_classes=0`` -> pooled feature
    vector). The two vectors are concatenated and passed through a small MLP. The final
    sigmoid keeps the outputs in (0, 1).

    Args:
        n_out: number of regression targets.
        n_in: length of the concatenated feature vector. The default assumes 1280
            features per backbone. Pass ``None`` to infer it from the backbones'
            ``num_features``.
        h: hidden width of the MLP head.
        backbone: timm model name used for both streams.
    """

    def __init__(self, n_out, n_in=1280*2, h=516, backbone='efficientnetv2_s'):
        super().__init__()

        # Independent (non-shared) weights for the two views.
        self.stem_local = timm.create_model(backbone, num_classes=0)
        self.stem_global = timm.create_model(backbone, num_classes=0)
        if n_in is None:
            n_in = self.stem_local.num_features + self.stem_global.num_features

        # NOTE(review): h=516 may have been meant as 512. It is kept as is because
        # changing it would break loading existing state_dict checkpoints.
        self.linear1 = torch.nn.Sequential(
            torch.nn.Linear(n_in, h),
            torch.nn.ReLU(),
            torch.nn.Dropout()
        )
        self.linear2 = torch.nn.Sequential(
            torch.nn.Linear(h, h),
            torch.nn.ReLU(),
        )
        self.final = torch.nn.Sequential(
            torch.nn.Linear(h, n_out),
            torch.nn.Sigmoid(),
        )

    def forward(self, x_local, x_global):
        """Return sigmoid outputs of shape (batch, n_out) for a batch of paired views."""
        features = torch.cat((self.stem_local(x_local), self.stem_global(x_global)), dim=1)
        return self.final(self.linear2(self.linear1(features)))

In [None]:
import gc

# Drop whichever `model` is currently bound before building a new one: on a linear run
# this is the CPU backbone from the timm-config cell, and on a re-run it is the previous
# regressor. Delete first, then empty the CUDA cache, because cached blocks can only be
# released once no tensor references them. The guard makes the cell safe to re-run.
if 'model' in globals():
    del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Fresh regressor with one output per target column, trained with Adam under a one-cycle
# LR schedule. The scheduler is stepped once per batch, hence steps_per_epoch=len(train_dl).
model = MultiresRegressor(len(train_ds.target_labels))
model = model.cuda()
lr = 5e-4
epochs = 10
# opt = torch.optim.SGD(model.parameters(), lr=1e-4)
opt = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=lr,
    steps_per_epoch=len(train_dl), epochs=epochs)
# NOTE(review): leftover note mentions "RMSProp optimizer with decay 0.9"; the code above uses Adam.

In [None]:
import seaborn as sns
def plot_performance(train_se, val_se, labels=None):
    """Draw a 2-row heatmap (train, val) of per-target errors.

    Args:
        train_se, val_se: 1-D per-target error vectors of equal length (torch tensors
            or array-likes; each is converted independently).
        labels: target column labels. Defaults to the notebook-global
            ``train_ds.target_labels`` for backward compatibility.
    """
    def _to_numpy(v):
        if isinstance(v, torch.Tensor):
            return v.detach().cpu().numpy()
        return np.asarray(v)

    if labels is None:
        labels = train_ds.target_labels
    errors = np.vstack((_to_numpy(train_se), _to_numpy(val_se)))
    df = pd.DataFrame(data=errors, columns=labels, index=['train', 'val'])
    fig, ax = plt.subplots()
    sns.heatmap(df, ax=ax)
    ax.set_title('Per-target error')
    plt.show()
    

In [None]:
# plot_performance(np.random.rand(len(train_ds.target_labels)),
#                 np.random.rand(len(train_ds.target_labels)))

In [None]:
import time
criteria = torch.nn.MSELoss()
n_targets = len(train_ds.target_labels)
n_train, n_val = len(train_dl.dataset), len(val_dl.dataset)
for epoch in range(epochs):
    # Sample-weighted running sums; they become per-sample means after the epoch.
    train_loss, val_loss = 0., 0.
    train_se = torch.zeros(n_targets)  # per-target sum of squared errors (CPU)
    val_se = torch.zeros(n_targets)
    start = time.time()

    model.train()
    for b in train_dl:
        x_local, x_global = b['images']['1x'].cuda(), b['images']['4x'].cuda()
        y = b['targets'].cuda().type(torch.float32)
        logits = model(x_local, x_global)
        loss = criteria(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()  # OneCycleLR is stepped once per batch

        # Accumulate detached values only. Adding the graph-carrying `loss` tensor, or
        # copying `logits - y` into a buffer, would keep every batch's autograd graph
        # (and its activations) alive until the end of the epoch.
        train_loss += loss.item() * y.shape[0]
        train_se += torch.sum(torch.square(logits.detach() - y), dim=0).cpu()
    time_delta = time.time() - start

    model.eval()
    with torch.no_grad():
        for b in val_dl:
            x_local, x_global = b['images']['1x'].cuda(), b['images']['4x'].cuda()
            y = b['targets'].cuda().type(torch.float32)
            logits = model(x_local, x_global)
            val_loss += criteria(logits, y).item() * y.shape[0]
            val_se += torch.sum(torch.square(logits - y), dim=0).cpu()

    # Per-sample means: a smaller final batch is weighted correctly, and train/val
    # numbers are on the same scale.
    train_loss /= n_train
    val_loss /= n_val
    plot_performance(train_se / n_train, val_se / n_val)
    e_lr = opt.param_groups[0]['lr']
    print(f'epoch: {epoch}, train loss: {train_loss}, val loss: {val_loss}, time: {time_delta}, lr: {e_lr}')

    

In [None]:
# Checkpoint directory (created if missing).
model_dir = '/data/violet/sandbox/tcia_pda_run1/st/models'
Path(model_dir).mkdir(parents=True, exist_ok=True)

In [None]:
# Save the weights only (state_dict).
# NOTE(review): the file name 'st_10ep' is hard-coded and does not follow `epochs`.
torch.save(model.state_dict(), os.path.join(model_dir, 'st_10ep'))

In [None]:
# Rebuild the architecture and load the saved weights.
# NOTE(review): torch.load unpickles the file; only load checkpoints you created yourself
# (or pass weights_only=True on PyTorch versions that support it).
model = MultiresRegressor(len(train_ds.target_labels))
model = model.cuda()
checkpoint = torch.load(os.path.join(model_dir, 'st_10ep'))
model.load_state_dict(checkpoint)

In [None]:
# Training directory re-wrapped with the deterministic pipeline as `transform` (its
# `val_transform` default is deterministic too), so predictions on training tiles are
# augmentation-free.
a_ds = MultiresDataset(train_dir, transform=val_transform)

In [None]:
# Fixed order, so predictions can be collected batch by batch and indexed by sample id.
a_dl = DataLoader(
    a_ds,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
model.eval()
pred_frames = []
with torch.no_grad():
    # Training tiles first (augmentation-free via a_dl), then validation tiles.
    for loader in (a_dl, val_dl):
        for b in loader:
            x_local, x_global = b['images']['1x'].cuda(), b['images']['4x'].cuda()
            logits = model(x_local, x_global)
            pred_frames.append(pd.DataFrame(data=logits.detach().cpu().numpy(),
                                            columns=train_ds.target_labels,
                                            index=b['sample']))
# Concatenate once; concatenating inside the loop is quadratic in the number of batches.
pred_df = pd.concat(pred_frames, axis=0) if pred_frames else None

In [None]:
pred_df

In [None]:
# Load one Visium section and prefix its barcodes with the section id, so they can
# presumably be matched against pred_df's sample index (`<section>_<barcode>`) — confirm.
# NOTE(review): `s` and the spaceranger directory are hard-coded separately and must
# refer to the same section.
s = 'HT270P1_S1H1Fs5U1'
a = sc.read_visium(f'/data/spatial_transcriptomics/spaceranger_outputs/pancreatic/HT270P1-S1H1Fs5U1Bp1/')
a.obs.index = [f'{s}_{x}' for x in a.obs.index]
a

In [None]:
# Predictions that exist for spots of this section.
pred_df.loc[[x for x in a.obs.index.to_list() if x in pred_df.index]]

In [None]:
a.obs

In [None]:
# Copy each predicted target into a.obs as `predicted_<gene>`. Spots without a
# prediction get 0. One reindex replaces a per-spot membership test and .loc lookup.
aligned = pred_df.reindex(a.obs.index, fill_value=0.)
for gene in pred_df.columns:
    a.obs[f'predicted_{gene}'] = aligned[gene].to_numpy()

In [None]:
a.obs

In [None]:
# Predicted spatial maps; alpha_img=0. hides the tissue image underneath the spots.
sc.pl.spatial(a, color='predicted_PTPRC', alpha_img=0.)

In [None]:
sc.pl.spatial(a, color='predicted_CD8A', alpha_img=0.)

In [None]:
sc.pl.spatial(a, color='predicted_EPCAM', alpha_img=0.)

In [None]:
sc.pl.spatial(a, color='predicted_BGN', alpha_img=0.)

In [None]:
sc.pl.spatial(a, color='predicted_PRSS1', alpha_img=0.)

In [None]:
# Second Visium section: prefix the barcodes with the section id, then attach predictions
# as `predicted_<gene>` columns (0 for spots without a prediction).
# NOTE(review): this id has a `_U1` suffix, unlike 'HT270P1_S1H1Fs5U1' above — confirm it
# matches the tile naming.
s = 'HT264P1_S1H2Fs1_U1'
a = sc.read_visium('/data/spatial_transcriptomics/spaceranger_outputs/pancreatic/HT264P1-S1H2Fs1U1Bp1/')
a.obs.index = [f'{s}_{x}' for x in a.obs.index]

aligned = pred_df.reindex(a.obs.index, fill_value=0.)
for gene in pred_df.columns:
    a.obs[f'predicted_{gene}'] = aligned[gene].to_numpy()
a

In [None]:
# log1p mutates a.X in place, and scanpy records it in a.uns['log1p']. The guard keeps a
# re-run of this cell from log-transforming the counts twice.
if 'log1p' not in a.uns:
    sc.pp.log1p(a)

In [None]:
# Observed (log1p of the counts) vs predicted maps, side by side.
# NOTE(review): predictions are on the [0, 1] per-column-max scale of targets.txt, while
# observed values are log1p counts, so compare spatial patterns rather than absolute values.
sc.pl.spatial(a, color=['PTPRC', 'predicted_PTPRC',
                       'EPCAM', 'predicted_EPCAM',
                       'BGN', 'predicted_BGN'], alpha_img=0., ncols=2)

In [None]:
sc.pl.spatial(a, color=['EPCAM', 'predicted_EPCAM'], alpha_img=0.)

In [None]:
sc.pl.spatial(a, color=['BGN', 'predicted_BGN'], alpha_img=0.)

In [None]:
sc.pl.spatial(a, color='PTPRC', alpha_img=0.)

In [None]:
# Raw TCIA slide tiles available for prediction (one slide directory).
fps = list(listfiles('/data/violet/sandbox/tcia_pda_run1/tcia/raw/C3L-00017-21/', regex=r'.jpeg$'))
len(fps), fps[:2]

In [None]:
# NOTE(review): this dataset is overwritten by the next cell, so on Run All only
# C3L-00401-22 is predicted.
a_ds = MultiresPredictionDataset('/data/violet/sandbox/tcia_pda_run1/tcia/raw/C3L-00017-21/', transform=val_transform)

In [None]:
a_ds = MultiresPredictionDataset('/data/violet/sandbox/tcia_pda_run1/tcia/raw/C3L-00401-22/', transform=val_transform)

In [None]:
# Fixed order, so prediction rows line up with a_ds.samples.
a_dl = DataLoader(
    a_ds,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
len(a_ds), a_ds.samples

In [None]:
model.eval()
pred_frames = []
with torch.no_grad():
    for b in a_dl:
        # MultiresPredictionDataset items carry no targets; only the two image views are used.
        x_local, x_global = b['images']['1x'].cuda(), b['images']['4x'].cuda()
        logits = model(x_local, x_global)
        pred_frames.append(pd.DataFrame(data=logits.detach().cpu().numpy(),
                                        columns=train_ds.target_labels,
                                        index=b['sample']))
# Build the frame once rather than concatenating inside the loop.
pred_df = pd.concat(pred_frames, axis=0) if pred_frames else None

In [None]:
pred_df

In [None]:
from violet.utils.analysis import display_2d_scatter

In [None]:
# One scatter plot per predicted target.
# NOTE(review): display_2d_scatter is project code; presumably it derives tile
# coordinates from pred_df's sample index — confirm.
for h in pred_df.columns:
    print(h)
    display_2d_scatter(pred_df, h)
    plt.show()

In [None]:
display_2d_scatter(pred_df, 'PTPRC')

In [None]:
display_2d_scatter(pred_df, 'EPCAM')

In [None]:
display_2d_scatter(pred_df, 'BGN')

In [None]:
# Zoomed view of a sub-region.
# NOTE(review): `region` is presumably ((x_min, x_max), (y_min, y_max)) and `scale` a
# marker-size factor — confirm against display_2d_scatter's signature.
region = [(-115, -90), (60, 85)]
for h in pred_df.columns:
    print(h)
    display_2d_scatter(pred_df, h, region=region, scale=.015)
    plt.show()

In [None]:
# Scratch: inspect the bare backbone's output shapes.
# NOTE(review): this rebinds `model`, replacing the trained MultiresRegressor loaded above.
model = timm.create_model('efficientnetv2_s', num_classes=0)

In [None]:
# Pooled feature vector (num_classes=0) for one random 288x288 RGB input.
x = torch.randn(1, 3, 288, 288)
x = model(x)

In [None]:
x.shape

In [None]:
# Unpooled feature map from forward_features for the same input size.
x = torch.randn(1, 3, 288, 288)
x = model.forward_features(x)

In [None]:
x.shape

In [None]:
# Repeats the config lookup from the top of the notebook (same timm imports).
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

config = resolve_data_config({}, model=model)

In [None]:
config

In [None]:
# FIXME(review): this cell looks like leftover template code and cannot work as written:
#   - `csv_file`, `root_dir` and `io` are not defined anywhere in this notebook, so
#     constructing the class raises NameError;
#   - the `fps` argument is never used;
#   - running the cell rebinds `MultiresDataset`, shadowing the working class defined
#     in the data-loader section, so cells that build MultiresDataset break afterwards.
# Delete this cell (or rename the class) instead of running it.
class MultiresDataset(Dataset):
    def __init__(self, fps, transform=None):
        # NOTE(review): csv_file and root_dir are undefined names here.
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Row layout assumed: column 0 = image file name, remaining columns = landmark
        # coordinates, reshaped into pairs below — TODO confirm.
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)  # NOTE(review): `io` is never imported in this notebook.
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
