## Step 0: Environment

In [None]:
import os
# Run from the repository root so the relative `utils`, `models`, `configs`, `training`
# and `data/...` paths used below resolve. Set the PSP_ROOT environment variable to
# point elsewhere instead of editing a hard-coded absolute path.
PROJECT_ROOT = os.environ.get('PSP_ROOT', '/home/extra/micheal/pixel2style2pixel')
os.chdir(PROJECT_ROOT)

In [None]:
from argparse import Namespace
import time
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

sys.path.append(".")
sys.path.append("..")

from utils.common import tensor2im, log_input_image
from models.psp import pSp

# added imports
import os
import imageio
import matplotlib.pyplot as plt
from configs.transforms_config import SegToImageTransforms
from glob import glob
from training.coach import Coach


%matplotlib inline
%load_ext autoreload
%autoreload 2

## Step 2: Load Trained Model

In [None]:
# Checkpoint of the trained pSp model, resolved against the repository root that the
# first cell chdir()s into. Made absolute because the path is also stored in the
# options as `stylegan_weights`.
model_path = os.path.abspath('experiments/ioct_seg2bscan2/checkpoints/best_model.pt')

In [None]:
# Rebuild the training options stored in the checkpoint and adapt them for inference.
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']

# Same options as training, overridden to run one sample at a time and to load the
# weights from this checkpoint (`load_partial_weights` is a project option consumed by Coach).
optss = Namespace(**{**opts, 'batch_size': 1, 'stylegan_weights': model_path, 'load_partial_weights': True})

coach = Coach(optss)
device = torch.device(coach.opts.device)

In [None]:
# Put the network into evaluation mode before running inference.
coach.net = coach.net.eval()

In [None]:
# Iterator over the training DataLoader, so the next cell can pull batches with next().
train_loader = iter(coach.train_dataloader)

In [None]:
# Run one training batch through the model and show prediction, ground truth and label map.
batch = next(train_loader)
# Use the device configured in the options instead of hard-coding `.cuda()`.
test_label = batch[0].to(device).float()
print('label shape', test_label.shape)
with torch.no_grad():
    test_bscan = coach.net(test_label)
    print('test_bscan shape', test_bscan.shape)

pred = test_bscan[0][0].cpu()
bscan = batch[1][0][0]
# Collapse the label channels of the first sample into a class-index map.
label = batch[0][0].argmax(dim=0)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, image, name in zip(axes, (pred, bscan, label), ('pred', 'bscan', 'label')):
    ax.imshow(image)
    ax.set_xlabel(name)

In [None]:
# Round-trip check: encode the label into latents, feed those latents back in with
# input_code=True, and test whether the latents returned the second time are identical.
# NOTE(review): this is exact float equality. Whether pSp hands `input_code` latents
# straight to the decoder or re-maps them is not visible here; confirm before relying on it.
with torch.no_grad():
    test_bscan, latent = coach.net(test_label, return_latents=True)
    bscan_a, latent_a = coach.net(latent, input_code=True, return_latents=True)
torch.all(latent==latent_a)

### Debug: Compare manual loading against the dataset pipeline

Manually load a label / B-scan pair from disk:

In [None]:
label_paths = 'data/ioct/labels/train/*'

# Sort so that index 673 picks the same file on every machine and run.
# Raw glob order depends on the filesystem and is not reproducible.
label_path = sorted(glob(label_paths))[673]
# The matching B-scan lives at the same relative location under `bscans` instead of `labels`.
bscan_path = label_path.replace('labels', 'bscans', 1)
label = imageio.imread(label_path)
bscan = imageio.imread(bscan_path)
# Scale class indices by 50 so the classes are visible next to the B-scan.
aggregated = np.concatenate((label * 50, bscan), axis=1)
plt.imshow(aggregated)

Load with the image library (PIL):

Load the inference transforms:

In [None]:
# Build the project's inference transform from SegToImageTransforms with label_nc=5 and
# output_nc=1. These are presumably the label class count and the image channel count;
# confirm against configs/transforms_config.
transform_opt = Namespace(label_nc=5, output_nc=1)
transform_dict = SegToImageTransforms(transform_opt).get_transforms()
img_transforms = transform_dict['transform_inference']

Convert the label with the project's transform code:

In [None]:
# Re-open the same label as a single-channel PIL image and push it through the project transform.
label = Image.open(label_path).convert('L')
transformed_label = img_transforms(label)
print('transformed label is of shape', transformed_label.shape)
# Visualise the class-index map recovered from the transformed channels.
plt.imshow(transformed_label.argmax(dim=0))

Manually feed the transformed label to the network loaded above:

In [None]:
import cv2

# Feed the manually transformed label through the network loaded above.
batched_label = transformed_label.unsqueeze(0).to(device).float()
print('transformed label shape', batched_label.shape)
with torch.no_grad():
    pred_batch = coach.net(batched_label)
    print('test_bscan shape', pred_batch.shape)

pred_bscan = pred_batch[0][0].cpu()

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    ('pred', pred_bscan),
    ('bscan', cv2.resize(bscan, dsize=(256, 256))),
    # `label` is the PIL image opened two cells above. Nearest-neighbour resampling
    # keeps class indices intact; PIL's default filter would blend neighbouring labels.
    ('label', label.resize((256, 256), resample=Image.NEAREST)),
)
for ax, (name, image) in zip(axes, panels):
    ax.imshow(image)
    # axis('off') also hides x-labels, so caption each panel with a title instead.
    ax.set_title(name)
    ax.axis('off')

#### Custom modules for transformation

In [None]:
import torchvision.transforms as transforms

class Conver2Uint8(torch.nn.Module):
    '''
    Map a [0, 1] float tensor (e.g. the output of ``transforms.ToTensor()``) back to
    whole-number pixel values in [0, 255].

    Only the values are rounded; the floating-point dtype of the input is kept.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, img):
        """
        Args:
            img (Tensor): Float tensor with values scaled to [0, 1].

        Returns:
            Tensor: ``img * 255`` rounded to the nearest integer value.
        """
        # Round rather than truncate: x / 255 * 255 is not always exactly x in floating point.
        return torch.round(img * 255)


# Correctly spelled alias; the original name is kept for existing callers.
Convert2Uint8 = Conver2Uint8
    
class MyResize(torch.nn.Module):
    '''
    Downsample an image by an integer factor, keeping the maximum value of each block.

    Unlike interpolating resizes this never produces values absent from the input, which
    suits label maps; within each block the largest class index wins.
    Each input spatial dim must be an exact multiple of the corresponding target dim.
    '''
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, img):
        """
        Args:
            img (Tensor): 3D (C, H, W) or 4D (N, C, H, W) tensor accepted by max_pool2d.

        Returns:
            Tensor: Image whose last two dims equal ``self.size``.

        Raises:
            AssertionError: If H or W is not a multiple of the target height/width.
        """
        h, w = img.shape[-2], img.shape[-1]
        target_h, target_w = self.size
        assert h % target_h == 0, f"h({h}) must be divisible by target_h({target_h})"
        assert w % target_w == 0, f"w({w}) must be divisible by target_w({target_w})"
        # Each output pixel is the max over a non-overlapping (kernel_h, kernel_w) block.
        kernel_h = h // target_h
        kernel_w = w // target_w
        return torch.nn.functional.max_pool2d(img, kernel_size=(kernel_h, kernel_w), stride=(kernel_h, kernel_w))
    
class ToOneHot(torch.nn.Module):
    '''
    Convert a map of class indices into a channels-first one-hot tensor.
    '''
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, img):
        """
        Args:
            img (Tensor): Class-index map of shape (1, h, w) (only channel 0 is used)
                or (h, w). Values are converted with ``.long()`` and must lie in
                [0, num_classes).

        Returns:
            Tensor: int64 tensor of shape (num_classes, h, w), 1 in the channel of each
            pixel's class and 0 elsewhere.
        """
        labels = img.long()
        if labels.dim() == 3:
            # Channel-first input: the class indices live in the first channel.
            labels = labels[0]
        one_hot = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
        # one_hot appends the class axis last: (h, w, C) -> (C, h, w).
        return one_hot.permute(2, 0, 1)

# Hand-written label pipeline: PIL image -> [0, 1] tensor -> whole-number class values
# -> 256x256 block-max downsample -> 5-channel one-hot tensor.
resize = MyResize((256, 256))
# torchvision nearest-neighbour resize; not used by my_transforms.
off_resize = transforms.Resize(
    (256, 256),
    interpolation=transforms.InterpolationMode.NEAREST,
)
my_transforms = transforms.Compose([
    transforms.ToTensor(),
    Conver2Uint8(),
    resize,
    ToOneHot(5),
])

In [None]:
pil_img = Image.open(label_path).convert('L')
# Reference: the project's inference transform, reduced to a class-index map.
reference_img = img_transforms(pil_img)
reverted_label = reference_img.argmax(dim=0)
# Candidate: the hand-written pipeline defined above.
transformed_img = my_transforms(pil_img)
print(torch.unique(transformed_img))
# Do both pipelines produce the same class map?
print('matches inference transform:', torch.equal(reverted_label, transformed_img.argmax(dim=0)))
plt.imshow(transformed_img.argmax(dim=0) == 1)


In [None]:
# Shape of the hand-written pipeline output; ToOneHot(5) after MyResize((256, 256))
# should give (5, 256, 256).
print(transformed_img.shape)

Inference

In [None]:
# coach.net(img_transforms(label))

## Step 3: Locate instruments in the image

In [None]:
def get_coor_avg(label_map, label_num):
    """Mean pixel coordinates of one class in a 2-D label map.

    Args:
        label_map: 2-D numpy array or torch tensor (CPU or GPU) of class indices.
        label_num: Class index to locate.

    Returns:
        Tuple ``(x_avg, y_avg, n_label)``:
        - ``x_avg`` is the mean index along axis 0 (rows).
        - ``y_avg`` is the mean index along axis 1 (columns).
        - ``n_label`` is the number of matching pixels.
        When the class is absent, both averages are ``nan`` and ``n_label`` is 0.

    Raises:
        AssertionError: If ``label_map`` is not 2-D.
    """
    assert len(label_map.shape) == 2, f'label map must be a 2D array, but got shape {label_map.shape}'
    if isinstance(label_map, torch.Tensor):
        # .cpu() so that tensors living on the GPU can be converted as well.
        label_map = label_map.detach().cpu().numpy()
    coords = np.argwhere(label_map == label_num)  # (n_label, 2) array of (row, col) pairs
    n_label = len(coords)
    if n_label == 0:
        # Absent class: return nan explicitly instead of averaging an empty slice,
        # which would emit a RuntimeWarning for every such image.
        return (np.nan, np.nan, 0)
    x_avg = np.average(coords[:, 0])
    y_avg = np.average(coords[:, 1])
    return (x_avg, y_avg, n_label)

In [None]:
from glob import glob

label_paths = 'data/ioct/labels/train/*'

# Sorted so that index 5843 picks the same file on every run.
# Raw glob order is filesystem-dependent.
label_path = sorted(glob(label_paths))[5843]
bscan_path = label_path.replace('labels', 'bscans', 1)
# Boolean mask of class 2 (the instrument), shown next to the B-scan.
label = imageio.imread(label_path) == 2
bscan = imageio.imread(bscan_path)
aggregated = np.concatenate((label * 50, bscan), axis=1)
plt.imshow(aggregated)

In [None]:
from tqdm import tqdm
# For every training sample, record the mean position and pixel count of classes 4 and 2.
coord_list = []
for idx, (label, bscan) in enumerate(tqdm(coach.train_dataset)):
    label = label.argmax(dim=0)  # channels -> class-index map
    (x4, y4, n4), (x2, y2, n2) = get_coor_avg(label, 4), get_coor_avg(label, 2)
    coord_list.append(dict(idx=idx, l4=(x4, y4), n4=n4, l2=(x2, y2), n2=n2))

In [None]:
import pickle

# Persist the per-sample coordinates so the slow dataset sweep can be skipped next time.
with open("experiments/coords.pickle", mode="wb") as output_file:
    pickle.dump(coord_list, output_file)

In [None]:
import pickle
# Reload the cached coordinates. pickle can execute arbitrary code, so only load files
# this notebook wrote itself.
with open("experiments/coords.pickle", mode="rb") as coords_file:
    coord_list = pickle.load(coords_file)

## Step 4: Sort according to label2's x-coordinate

Label 2 is the instrument. Note that the "x" coordinate returned by `get_coor_avg` is the mean row index (array axis 0).

In [None]:
import pandas as pd

def key_x2(item):
    """Sort key: the first ("x", i.e. row) component of an entry's label-2 mean coordinate."""
    l2_coords = item['l2']
    return l2_coords[0]

# Keep samples whose instrument (label 2) covers enough pixels to give a stable centroid.
MIN_INSTRUMENT_PIXELS = 100
l2x = sorted(
    (e for e in coord_list
     if not np.isnan(e['l2'][0]) and not np.isnan(e['l2'][1]) and e['n2'] > MIN_INSTRUMENT_PIXELS),
    key=key_x2,
)
print(f"{len(l2x)} / {len(coord_list)} has instrument inside")

print('\nStatistics of x-axis:')
# 'x' is the mean row index (array axis 0) computed by get_coor_avg.
l2x_extracted = [e['l2'][0] for e in l2x]
df_l2x = pd.DataFrame(l2x_extracted)
df_l2x.describe()

## Step 5: Extract the style latent of each selected sample

In [None]:
# Encode each selected sample (in sorted order) and keep its style latent.
l2x_latents = []

with torch.no_grad():
    for e in tqdm(l2x):
        label, bscan = coach.train_dataset[e['idx']]
        # Add a batch dimension and move to the model's device.
        label = label.unsqueeze(0).float().to(device)
        pred, latent = coach.net(label, return_latents=True)
        # Drop the batch dimension: one latent array per sample.
        l2x_latents.append(latent.cpu().numpy()[0])

print(f"length of l2x latent: {len(l2x_latents)}")
print(f"each latent is of shape: {l2x_latents[0].shape}")

In [None]:
# Cache the latents of the selected samples for later sessions.
with open("experiments/latents_l2x.pickle", mode="wb") as latents_file:
    pickle.dump(l2x_latents, latents_file)

In [None]:
# Reload the cached latents (trusted file written by this notebook; pickle is not safe for foreign data).
with open("experiments/latents_l2x.pickle", mode="rb") as latents_file:
    l2x_latents = pickle.load(latents_file)

## Step 6: Find deviation of positive samples from average

### 1. Compute mean and standard deviation of all style vectors

In [None]:
# One row per sample: each latent flattened into a single vector.
l2x_latent_flat = np.asarray([lat.ravel() for lat in l2x_latents])
l2x_latent_flat.shape

Mean & standard deviation:

`p` means population

In [None]:
# Per-dimension population statistics over all selected samples (axis 0 = samples).
mean_p = l2x_latent_flat.mean(axis=0)
print('mean is of shape', mean_p.shape)
std_p = l2x_latent_flat.std(axis=0)
print('standard deviation is of shape', std_p.shape)

### 2. Find positive examples

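Take the last 200 samples after sorting (those with the largest label-2 x-coordinate) as positive exemplars. **TODO:** confirm this threshold.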

Here `e` means exemplar.

In [None]:
# Exemplars = the last N_EXEMPLARS rows. Rows follow the ascending sort by label-2 'x'
# (mean row index), so these are the samples with the largest 'x'.
N_EXEMPLARS = 200
latents_e = l2x_latent_flat[-N_EXEMPLARS:, :]
latents_e.shape

### 3. Compute the normalized difference of the exemplars

Normalized difference (per-dimension z-score) of each positive sample from the population distribution.

In [None]:
# z-score of every exemplar dimension relative to the population statistics.
normalized_diff_e = np.divide(latents_e - mean_p, std_p)
normalized_diff_e.shape

### 4. Compute the mean and std of the normalized difference

In [None]:
# Per-dimension mean and spread of the exemplars' z-scores.
mean_e = normalized_diff_e.mean(axis=0)
print('mean is of shape', mean_e.shape)
std_e = normalized_diff_e.std(axis=0)
print('standard deviation is of shape', std_e.shape)

### 5. Compute the impact factor

The impact factor is the magnitude of the mean divided by the standard deviation.

In [None]:
# Impact = |mean z-score| / std of z-scores: large when the exemplars deviate from the
# population consistently in the same direction.
impact_e = np.abs(mean_e) / std_e

Show statistics

In [None]:
# Summary statistics (count, mean, std, min/quartiles/max) of the impact factor over
# all latent dimensions.
impact_df = pd.DataFrame(impact_e)
impact_df.describe()

### 6. Sort and list the 100 most impactful latent locations

In [None]:
# Flat latent indices ordered from most to least impactful (argsort of the negated
# scores; nan scores, if any, are placed last by argsort).
sorted_impact = np.argsort(-impact_e)

In [None]:
# (flat latent index, impact) for the TOP_K most impactful dimensions, as plain Python
# numbers. The slice also tolerates fewer than TOP_K dimensions.
TOP_K = 100
idx_with_impact = [(int(i), float(impact_e[i])) for i in sorted_impact[:TOP_K]]
print(idx_with_impact)

## Step 7: Try to manipulate a picture

In [None]:
# Start from the stored latent of one sample, restored to the per-sample latent shape
# recorded in l2x_latents instead of a hard-coded (18, 512).
SAMPLE_IDX = 8
latent = torch.from_numpy(l2x_latent_flat[SAMPLE_IDX].reshape(l2x_latents[0].shape)).float().to(device)
latent.shape

In [None]:
with torch.no_grad():
    output, latent_temp = coach.net.decoder([latent.unsqueeze(0)], input_is_latent=True, randomize_noise=True)

In [None]:
plt.imshow(output.detach().cpu()[0][0])

In [None]:
latent = latent.reshape(-1)
latent[2133] -= 10
latent[1077] -= 10
latent = latent.reshape(18, 512)

In [None]:
with torch.no_grad():
    output, latent_temp = coach.net.decoder([latent.unsqueeze(0)], input_is_latent=True, randomize_noise=True)
plt.imshow(output.detach().cpu()[0][0])