# Test tiling

In [10]:
# Automatically reload edited modules (e.g. tissue_purifier) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
%matplotlib inline
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
import pytorch_lightning as pl
import lightly
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize

from tissue_purifier.loss import NoisyLoss
from tissue_purifier.util_data import *

### Load the data

In [12]:
# Directory holding one CSV per tissue. To read the data from somewhere else
# (e.g. /home/jupyter/data/slide-seq/original_data/ on the cloud VM), set the
# SLIDESEQ_DATA_DIR environment variable instead of editing this cell.
root = os.environ.get("SLIDESEQ_DATA_DIR", "/Users/ldalessi/REPOS/ML_for_slideseq/data/")

# os.path.join works whether or not `root` ends with a path separator.
df_wt1 = pd.read_csv(os.path.join(root, "wt1.csv"))
df_wt2 = pd.read_csv(os.path.join(root, "wt2.csv"))
df_wt3 = pd.read_csv(os.path.join(root, "wt3.csv"))
df_dis1 = pd.read_csv(os.path.join(root, "dis1.csv"))
df_dis2 = pd.read_csv(os.path.join(root, "dis2.csv"))
df_dis3 = pd.read_csv(os.path.join(root, "dis3.csv"))

### Split each tissue into left and right halves (at the median x-coordinate)

In [13]:
def split_left_right(df, column="x"):
    """Split a tissue at the median of `column` into (left, right) halves.

    Rows with ``df[column] < median`` go to the left half and the rest
    (``>= median``) go to the right half. Rows whose `column` is missing
    fall in neither half.

    Args:
        df: table of beads/cells with a numeric `column`.
        column: coordinate used for the split (default "x").

    Returns:
        A ``(left, right)`` tuple of DataFrames.
    """
    median = df[column].median()
    return df[df[column] < median], df[df[column] >= median]


df_wt1_left, df_wt1_right = split_left_right(df_wt1)
df_wt2_left, df_wt2_right = split_left_right(df_wt2)
df_wt3_left, df_wt3_right = split_left_right(df_wt3)
df_dis1_left, df_dis1_right = split_left_right(df_dis1)
df_dis2_left, df_dis2_right = split_left_right(df_dis2)
df_dis3_left, df_dis3_right = split_left_right(df_dis3)

# Class label and source file of each tissue, in the same order as all_df_*.
labels_left = ["wt", "wt", "wt", "dis", "dis", "dis"]
labels_right = list(labels_left)
filename_left = ["wt1", "wt2", "wt3", "dis1", "dis2", "dis3"]
filename_right = list(filename_left)
all_df_left = [df_wt1_left, df_wt2_left, df_wt3_left, df_dis1_left, df_dis2_left, df_dis3_left]
# The right-hand list must hold the *_right halves (it used to repeat the left ones).
all_df_right = [df_wt1_right, df_wt2_right, df_wt3_right, df_dis1_right, df_dis2_right, df_dis3_right]

### Define the default values

In [14]:
num_workers = 8
batch_size = 128
seed = 1
max_epochs = 100
input_size = 224        # side length of the images produced by train_transform (Resize)
num_ftrs = 32
pixel_size = 4.0
crop_size = input_size  # CenterCrop size used in train_transform
input_channels = 9      # channels of the rasterized images (cf. the dense shapes printed below)
n_element_min = 200

# `seed` was previously defined but never applied. Seed every RNG the pipeline
# may draw from (random crops, augmentations, model initialization) so that
# Restart & Run All gives reproducible results.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

### Training transform (data augmentation)

In [15]:
# Augmentation pipeline handed to the collate function of the training loader.
# The first two steps come from tissue_purifier and presumably operate on the
# sparse crops before they are stacked into a dense tensor (TODO confirm).
tvt = torchvision.transforms
train_transform = tvt.Compose(
    [
        DropoutSparseTensor(dropout=(0.0, 0.4)),
        StackTensor(dim=-4),
        RandomGaussianBlur(sigma=(1.0, 1.0)),
        tvt.RandomAffine(
            degrees=180,
            scale=(0.75, 1.25),
            shear=0.0,
            interpolation=tvt.InterpolationMode.NEAREST,
            fill=0,
        ),
        RandomIntensity(factor=(0.7, 1.3)),
        tvt.CenterCrop(size=crop_size),
        tvt.Resize(input_size),
    ]
)

### Left TrainDataset and DataLoader

In [16]:
# Rasterize each left-half tissue into a SparseImage, using the shared `pixel_size`
# from the config cell (previously a duplicated literal 4.0).
sparse_images_left = [
    SparseImage.from_panda(
        df, x="x", y="y", category="max_cell_type",
        pixel_size=pixel_size, padding=10)
    for df in all_df_left
]

# Random crops drawn from each tissue per dataset item. An alternative is to draw
# enough crops to fill one batch: ceil(batch_size / len(sparse_images_left)).
n_crops_for_tissue = 2

dataset_train_left = SparseDataset(
    x=sparse_images_left,
    y=labels_left,
    z=filename_left,
    # Crops are 1.5x larger than `crop_size`, presumably so the rotation in
    # train_transform followed by CenterCrop(crop_size) leaves no empty corners.
    # NOTE(review): n_element_min=100 here differs from the config value
    # `n_element_min = 200`; confirm which one is intended.
    transform_x=RandomCropSparseTensor(n_crops=n_crops_for_tissue,
                                       crop_size=int(1.5 * crop_size),
                                       n_element_min=100),
    # Repeat each label once per crop so labels line up with the crops.
    transform_y=Interleave(n_repeat=n_crops_for_tissue),
)

# A single batch holds every dataset item (tissue), i.e.
# len(dataset_train_left) * n_crops_for_tissue images after collation.
dataloader_train_left = DataLoaderWithLoad(
    dataset_train_left,
    batch_size=len(dataset_train_left),
    collate_fn=SpecialCollateFn(transform=train_transform, simclr_output=False),
    shuffle=False)

number of elements ---> 15829
The dense shape of the image is -> torch.Size([9, 611, 1164])
number of elements ---> 16529
The dense shape of the image is -> torch.Size([9, 602, 845])
number of elements ---> 19603
The dense shape of the image is -> torch.Size([9, 602, 1170])
number of elements ---> 13597
The dense shape of the image is -> torch.Size([9, 601, 1170])
number of elements ---> 21387
The dense shape of the image is -> torch.Size([9, 596, 1170])
number of elements ---> 16718
The dense shape of the image is -> torch.Size([9, 523, 1150])


In [18]:
# Print the (label codes, filenames) of every dataset item. The first element
# of each item (presumably the crops) is not needed here.
for idx in range(len(dataset_train_left)):
    _, tissue_labels, tissue_names = dataset_train_left[idx]
    print(tissue_labels, tissue_names)
dataset_train_left.labels_to_code

[1, 1] ['wt1', 'wt1']
[1, 1] ['wt2', 'wt2']
[1, 1] ['wt3', 'wt3']
[0, 0] ['dis1', 'dis1']
[0, 0] ['dis2', 'dis2']
[0, 0] ['dis3', 'dis3']


{'dis': 0, 'wt': 1}

In [25]:
# Peek at one augmented batch: shapes, channel 0 of (up to) the first 10 crops,
# and the corresponding label codes.
imgs, labels, fname = next(iter(dataloader_train_left))
print(imgs.shape, labels.shape, len(fname))
preview = imgs[:10, 0]
show_tensor(preview, n_col=5, figsize=(12, 4), normalize_range=(0.0, 1.0), cmap='hot')
print(labels)

torch.Size([12, 9, 224, 224]) torch.Size([12]) 12
tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])


### Take the pretrained backbone and add a linear classifier on top

In [34]:
class Classifier(torch.nn.Module):
    """Linear probe: a frozen backbone followed by a trainable linear head.

    The training cell instantiates this class, so it must be defined (not
    commented out) for the notebook to survive Restart & Run All.

    Args:
        backbone: feature extractor; assumed to map a batch of images to
            features of shape (batch, n_features). TODO confirm for new backbones.
        n_classes: number of output classes.
        n_features: size of the backbone output (default 128).
    """

    def __init__(self, backbone: torch.nn.Module, n_classes: int, n_features: int = 128):
        super().__init__()
        self.n_classes = n_classes
        self.backbone = backbone
        # The backbone is frozen, so it is put in eval mode right away.
        self.backbone.eval()
        self.prediction_head = torch.nn.Linear(in_features=n_features, out_features=n_classes, bias=True)

    def train(self, mode: bool = True):
        """Set the training mode of the head, but always keep the backbone in eval mode.

        Running the backbone under ``no_grad`` prevents weight updates, but in
        training mode BatchNorm layers would still update their running
        statistics (and dropout would be active). Keeping the backbone in eval
        mode makes it truly frozen.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the prediction head receives gradients.
        with torch.no_grad():
            y = self.backbone(x)
        return self.prediction_head(y)

# Stand-in backbone; the real one should be the network trained with contrastive
# learning. This is an ImageNet-pretrained resnet18 whose first conv accepts
# `input_channels` channels and whose final layer outputs 128 features (the
# default `n_features` of the linear head). conv1 and fc are newly initialized
# here, so only the intermediate layers carry pretrained weights.
# NOTE: torch.hub.load downloads code and weights on first use and needs network access.
# The training cell below uses this `resnet18`, so it must not be commented out.
resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
resnet18.conv1 = torch.nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet18.fc = torch.nn.Linear(in_features=512, out_features=128, bias=True)


class RandomClassifier(torch.nn.Module):
    """Baseline classifier whose predictions do not depend on the input.

    A trainable linear head is applied to fresh Gaussian noise of shape
    (batch_size, n_features). The input contributes only its batch size,
    so this is presumably meant as a chance-level reference or a smoke test of
    the training loop.

    Args:
        n_classes: number of output classes.
        n_features: dimension of the random features fed to the head (default 12).
    """

    def __init__(self, n_classes: int, n_features: int = 12):
        super().__init__()
        self.prediction_head = torch.nn.Linear(in_features=n_features, out_features=n_classes, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_weight = self.prediction_head.weight
        with torch.no_grad():
            # Draw the noise with the head's dtype/device rather than the
            # input's. Integer or float64 inputs would otherwise either fail in
            # randn or mismatch the float32 head weights.
            y = torch.randn(
                (x.shape[0], self.prediction_head.in_features),
                dtype=head_weight.dtype,
                device=head_weight.device,
            )
        return self.prediction_head(y)

# Input-independent baseline with two output classes (n_classes is the first parameter).
dummy_classifier = RandomClassifier(2)

### Train loop

In [None]:
# Short training run of the linear probe on the left-half tiles.
# Requires `Classifier` and `resnet18` from the backbone cell to be defined.
classifier = Classifier(backbone=resnet18, n_classes=2)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1E-3, betas=(0.9, 0.999))
criterion = NoisyLoss()

# Deliberately few epochs; the config's `max_epochs` is for a full run.
max_epoch = 2
classifier.train()
for epoch in range(max_epoch):
    for imgs, labels, _ in dataloader_train_left:
        outputs = classifier(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # loss.backward() with no argument implies a scalar loss; report the last
    # batch's value as a quick sanity check that training progresses.
    print(f"epoch {epoch}: loss {loss.item():.4f}")