# Examine Encoder

Note: label IDs used in the segmentation maps:

```python
layer_labels=[1,2],
instrument_labels=[3,4]
```

## Step 1: Setup Environment

In [None]:
from argparse import Namespace
import time
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

sys.path.append(".")
sys.path.append("..")

from datasets import augmentations
from utils.common import tensor2im, log_input_image
from models.psp import pSp

# added imports
import os
import imageio
import matplotlib.pyplot as plt
from configs.transforms_config import SegToImageTransforms
from glob import glob


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os  # already imported in the setup cell; harmless re-import

# Run from the repository root so the relative paths used later
# (e.g. 'data/ioct/labels/train/*') resolve.
os.chdir('/home/extra/micheal/pixel2style2pixel')

## Step 2: Define Inference Parameters

In [None]:
# Transform config: label_nc=5, output_nc=1 (the same values the model
# options are overridden to in Test 1).
transform_opt = Namespace(label_nc=5, output_nc=1)
transform_dict = SegToImageTransforms(transform_opt).get_transforms()

# NOTE(review): not referenced by any other cell in this notebook.
pretrained_weight_path = '/home/extra/micheal/pixel2style2pixel/pretrained_models/psp_celebs_seg_to_face.pt'
checkpoint_path = '/home/extra/micheal/pixel2style2pixel/experiments/ioct_seg2bscan1/checkpoints/iteration_10000.pt'
label_paths = '/home/extra/micheal/pixel2style2pixel/data/ioct/labels/train/*'
# NOTE(review): glob() returns files in filesystem order, so index 863 only
# picks the same file on the same machine; consider sorted(glob(...)).
test_image_path = glob(label_paths)[863]

EXPERIMENT_ARGS = {
    'model_path': checkpoint_path,
    'image_path': test_image_path,
    'transform': transform_dict['transform_inference']
}
# Guard against a truncated checkpoint (e.g. an interrupted copy/download):
# a real checkpoint is well over 1 MB.
assert os.path.getsize(EXPERIMENT_ARGS['model_path']) > 1000000, 'the model checkpoint file is not complete'

## Test 1: Load Original Model

In [None]:
# Load the checkpoint on CPU and rebuild the options it was trained with.
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
# Point the options at this checkpoint (presumably read by pSp.load_weights()).
opts['checkpoint_path'] = model_path
optss = Namespace(**opts)
# Display the input channel count stored in the checkpoint options.
optss.input_nc

In [None]:
# Override channel counts (input_nc/label_nc = 5, output_nc = 1) to match
# the transform config defined in Step 2.
optss.output_nc = 1
optss.label_nc = 5
optss.input_nc = 5
# optss.checkpoint_path = None

In [None]:
# Build the network from the overridden options and load the checkpoint.
net_test = pSp(optss)
# net_test.latent_avg = torch.randn((18, 512))
net_test.load_weights()

In [None]:
# Smoke test: push one random 5-channel 256x256 input through the network.
input_tensor = torch.randn((1, 5, 256, 256))
# NOTE(review): GPU index is hard-coded; adjust for the local machine.
device = torch.device('cuda:1')
net_test = net_test.to(device)
# latent_avg is moved separately, presumably because Module.to() does not
# move it (plain tensor attribute rather than a registered buffer) — TODO confirm.
net_test.latent_avg = net_test.latent_avg.to(device)
input_tensor = input_tensor.to(device)
# NOTE(review): net_test is still in training mode here (no .eval() call).
with torch.no_grad():
    result_batch, latents = net_test(input_tensor.float(), randomize_noise=False, return_latents=True)
print('result_shape', result_batch.shape)

In [None]:
# Shape of the latent codes returned for the random input.
latents.shape

## Test 2: Dataloader

In [None]:
from options.train_options import TrainOptions
from training.coach import Coach
import json


# if os.path.exists(opts.exp_dir):
#     raise Exception('Oops... {} already exists'.format(opts.exp_dir))
# os.makedirs(opts.exp_dir)

# NOTE: vars() returns optss.__dict__ itself (a live view, not a copy), so
# attribute changes made to optss below are visible through opts_dict too.
opts_dict = vars(optss)
# pprint.pprint(opts_dict)
# with open(os.path.join(optss.exp_dir, 'opt.json'), 'w') as f:
#     json.dump(opts_dict, f, indent=4, sort_keys=True)

In [None]:
# Option flag consumed by the project code (presumably disables partial
# weight loading) — TODO confirm against pSp/Coach.
optss.load_partial_weights = False

In [None]:
# NOTE(review): hard-coded GPU index; whether Coach honours opts.device is
# not visible from this notebook — confirm.
optss.device = torch.device('cuda:1')

In [None]:
from training.coach import Coach
# Build the training harness; it exposes .net, .train_dataloader and
# .train_dataset, which are used below.
coach = Coach(optss)

In [None]:
# Load the checkpoint weights into the coach's network.
coach.net.load_weights()

In [None]:
# Manual iterator over training batches for inspection.
train_loader = iter(coach.train_dataloader)

In [None]:
# test
batch = next(train_loader)
# Put the labels on whatever device holds the network's weights. A bare
# .cuda() uses the *current* CUDA device (normally cuda:0), which need not be
# the GPU selected via optss.device above and would cause a device mismatch.
net_device = next(coach.net.parameters()).device
test_label = batch[0].to(net_device).float()
print('label shape', test_label.shape)
with torch.no_grad():
    test_bscan = coach.net(test_label)
    print('test_bscan shape', test_bscan.shape)

In [None]:
def tensor2im(var, grayscale=False):
    """Convert a network output tensor with values in [-1, 1] to a PIL image.

    Args:
        var: Tensor of shape (C, H, W), or (H, W) which is treated as a single
            channel. Values are mapped from [-1, 1] to [0, 255]; anything
            outside [-1, 1] is clipped.
        grayscale: If True, build a mode-'L' image (requires C == 1).
            Single-channel input is also rendered as 'L' when False, because
            PIL cannot build an image from an (H, W, 1) array.

    Returns:
        PIL.Image.Image of size (W, H).
    """
    arr = var.detach().cpu()
    if arr.dim() == 2:
        # Promote (H, W) to (1, H, W) so the channel handling below applies.
        arr = arr.unsqueeze(0)
    # CHW -> HWC, then map [-1, 1] -> [0, 1] and clip out-of-range values.
    arr = arr.permute(1, 2, 0).numpy()
    arr = np.clip((arr + 1) / 2, 0, 1) * 255
    # astype truncates (no rounding), as the original helper did.
    arr = arr.astype('uint8')
    if grayscale or arr.shape[2] == 1:
        return Image.fromarray(arr.squeeze(axis=2), 'L')
    return Image.fromarray(arr)

# Render the 7th generated sample (index 6 — assumes batch size >= 7).
tensor2im(test_bscan[6], grayscale=True)

In [None]:
# batch[0] is the tensor fed to the network in the smoke test above.
batch[0].shape

In [None]:
# batch[1] is presumably the target B-scan batch — TODO confirm against the dataset.
batch[1].shape

In [None]:
# Already enabled in the setup cell; repeating it is harmless.
%matplotlib inline

plt.imshow(batch[1][0])

## Test 3: Explore the latent code

### 1. Locate the instrument in the image

In [None]:
from glob import glob  # re-import; also imported in the setup cell

# Relative to the repository root set by os.chdir() earlier.
label_paths = 'data/ioct/labels/train/*'

label_path = glob(label_paths)[673]
# The matching B-scan lives at the same path with the first occurrence of
# 'labels' replaced by 'bscans'. Replacing only once keeps the rest of the
# path intact even if 'labels' occurs again later in it; the previous
# split('labels')[1] silently dropped everything after a second occurrence.
bscan_path = label_path.replace('labels', 'bscans', 1)
label = imageio.imread(label_path)
bscan = imageio.imread(bscan_path)
# Scale the small label IDs by 50 so they are visible next to the B-scan.
aggregated = np.concatenate((label * 50, bscan), axis=1)
plt.imshow(aggregated)

In [None]:
def get_coor_avg(label_map, label_num):
    """Return the mean pixel coordinate and pixel count of one label value.

    Args:
        label_map: Array of label IDs (e.g. a 2-D segmentation map).
        label_num: Label value to locate.

    Returns:
        Tuple ``(x_avg, y_avg, n_label)``:
        - ``x_avg`` is the mean index along axis 0 (the row);
        - ``y_avg`` is the mean index along axis 1 (the column);
        - ``n_label`` is the number of matching pixels.
        When the label is absent, both averages are NaN and ``n_label`` is 0.
    """
    coords = np.argwhere(label_map == label_num)
    n_label = len(coords)
    if n_label == 0:
        # Averaging an empty selection also yields NaN, but emits
        # RuntimeWarnings for every label map lacking this label.
        return (np.nan, np.nan, 0)
    x_avg = coords[:, 0].mean()
    y_avg = coords[:, 1].mean()
    return (x_avg, y_avg, n_label)

# label[int(x_avg)-5:int(x_avg)+5, int(y_avg)-5:int(y_avg)+5] = 5
# plt.imshow(label)

In [None]:
from tqdm import tqdm

# One summary entry per label file: centroid and pixel count of labels 4 and 2.
coord_list = []
for idx, label_path in enumerate(tqdm(glob(label_paths))):
    label = imageio.imread(label_path)
    if idx == 0:
        # Show the label-map shape once, for the first file only.
        print(label.shape)
    l4_x, l4_y, l4_count = get_coor_avg(label, 4)
    l2_x, l2_y, l2_count = get_coor_avg(label, 2)
    entry = dict(
        l4=(l4_x, l4_y),
        n4=l4_count,
        l2=(l2_x, l2_y),
        n2=l2_count,
        path=label_path,
    )
    coord_list.append(entry)

In [None]:
# Number of label files summarised above.
len(coord_list)

### 2. Sort by the first coordinate of the label-4 centroid (`l4.x`)

In [None]:
import copy
def key_x4(item):
    """Sort key: the first component of an entry's ``'l4'`` centroid pair."""
    l4_centroid = item['l4']
    return l4_centroid[0]

# Keep entries whose label-4 centroid is finite and covers more than 100
# pixels, ordered by the centroid's first coordinate.
l4x = sorted(
    (
        entry
        for entry in coord_list
        if not np.isnan(entry['l4']).any() and entry['n4'] > 100
    ),
    key=key_x4,
)
print(len(l4x))

In [None]:
# Show the first coordinate of the (up to) 20 lowest-sorted entries. Slicing
# avoids an IndexError when fewer than 20 entries passed the filter.
for entry in l4x[:20]:
    print(entry['l4'][0])

### 3. Get the style vector of each sample

#### 3.1 Load Trained Model

## Step 3: Load Trained Model

In [None]:
# Reload the same checkpoint used in Test 1 and print its stored options.
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')

opts = ckpt['opts']
pprint.pprint(opts)

In [None]:
# update the training options
opts['checkpoint_path'] = model_path

In [None]:
# NOTE(review): this rebinds optss from the fresh checkpoint dict. It drops
# the output_nc/label_nc/input_nc, load_partial_weights and device overrides
# applied to the previous optss in the test cells.
optss = Namespace(**opts)
net = pSp(optss)
net.load_weights()
net.eval()
print('Model successfully loaded!')
# NOTE(review): the Step 5 cell later executes `net = net_test`, so this
# model is not the one used for inference there.

In [None]:
# Shape of the stored average latent.
net.latent_avg.shape

In [None]:
# Print the module structure.
net

In [None]:
# Open one test B-scan to inspect its raw format.
im = Image.open('/home/extra/micheal/pixel2style2pixel/data/ioct/bscans/test/0c3839cd-0aa9-4e6e-bd4e-eb8f0520e2056578-012.png')

In [None]:
# Array shape after converting the B-scan to 8-bit grayscale ('L').
np.array(im.convert('L')).shape

## Step 4: Visualize Input

In [None]:
# Load the label map selected in Step 2 (EXPERIMENT_ARGS['image_path']).
image_path = EXPERIMENT_ARGS["image_path"]
original_image = Image.open(image_path)
# NOTE(review): `opts` is a dict here (it is subscripted in Step 3). If this
# block is re-enabled, use optss.label_nc rather than opts.label_nc.
# if opts.label_nc == 0:
#     original_image = original_image.convert("RGB")
# else:
#     original_image = original_image.convert("L")

In [None]:
plt.imshow(original_image)

In [None]:
# Inspect one raw item from the coach's training dataset.
coach.train_dataset[0]

## Step 5: Feed to Encoder

In [None]:
# Apply the inference transform from SegToImageTransforms (see Step 2).
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(original_image)
print("Transformed segmentation is of shape:", transformed_image.shape)

In [None]:
def run_on_batch(inputs, net, latent_mask=None):
    """Run the network on a batch, optionally injecting a random style per image.

    Args:
        inputs: Batch tensor. It is cast to float and moved to the device
            holding ``net``'s parameters.
        net: Module callable as ``net(x, randomize_noise=False,
            return_latents=True)`` and returning ``(images, latents)``.
        latent_mask: Optional mask passed through to ``net``. When given,
            each image is processed separately: a random 512-d code is mapped
            through ``net`` (``input_code=True``) and injected via
            ``inject_latent``.

    Returns:
        ``(result_batch, latents)``. In the masked case, both are
        concatenated along dim 0.
    """
    # Follow the network's device. The previous hard-coded "cuda" means
    # cuda:0 and fails when the model lives on another GPU or on the CPU.
    device = next(net.parameters()).device
    if latent_mask is None:
        result_batch, latents = net(inputs.to(device).float(), randomize_noise=False, return_latents=True)
        return result_batch, latents

    result_batch = []
    latents = []
    for input_image in inputs:
        # Get a latent vector to inject into this input image.
        vec_to_inject = np.random.randn(1, 512).astype('float32')
        _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to(device),
                                  input_code=True,
                                  return_latents=True)
        # Get the output image with the injected style vector.
        res, latent = net(input_image.unsqueeze(0).to(device).float(),
                          latent_mask=latent_mask,
                          inject_latent=latent_to_inject,
                          return_latents=True)
        result_batch.append(res)
        latents.append(latent)
    return torch.cat(result_batch, dim=0), torch.cat(latents, dim=0)

In [None]:
# NOTE(review): GPU index is hard-coded; adjust for the local machine.
device = torch.device('cuda:1')
# Use the Test 1 model (built with the channel overrides) for inference.
# This replaces the `net` built in Step 3.
net = net_test
net.load_weights()
# latent_avg is moved explicitly, presumably because Module.to() does not
# move it — TODO confirm.
net.latent_avg = net.latent_avg.to(device)
net = net.to(device)
# net_test was never switched out of training mode. Switch it here so layers
# such as BatchNorm/Dropout (if present) run in inference mode.
net.eval()
transformed_image = transformed_image.to(device)

In [None]:
with torch.no_grad():
    # CUDA kernels run asynchronously. Synchronize before reading the clock
    # so the measurement covers the GPU work, not just the kernel launches.
    # perf_counter() is the high-resolution clock intended for intervals.
    if transformed_image.is_cuda:
        torch.cuda.synchronize(transformed_image.device)
    tic = time.perf_counter()
    result_image, latents = run_on_batch(transformed_image.unsqueeze(0), net, latent_mask=None)
    if result_image.is_cuda:
        torch.cuda.synchronize(result_image.device)
    toc = time.perf_counter()
    print('Inference took {:.4f} seconds.'.format(toc - tic))

In [None]:
# Inspect the generated image batch and latent shapes.
print("images shape", result_image.shape)
print("latents shape", latents.shape)

In [None]:
def tensor2im(var, grayscale=False):
    """Convert a network output tensor with values in [-1, 1] to a PIL image.

    Args:
        var: Tensor of shape (C, H, W), or (H, W) which is treated as a single
            channel. Values are mapped from [-1, 1] to [0, 255]; anything
            outside [-1, 1] is clipped.
        grayscale: If True, build a mode-'L' image (requires C == 1).
            Single-channel input is also rendered as 'L' when False, because
            PIL cannot build an image from an (H, W, 1) array.

    Returns:
        PIL.Image.Image of size (W, H).
    """
    arr = var.detach().cpu()
    if arr.dim() == 2:
        # Promote (H, W) to (1, H, W) so the channel handling below applies.
        arr = arr.unsqueeze(0)
    # CHW -> HWC, then map [-1, 1] -> [0, 1] and clip out-of-range values.
    arr = arr.permute(1, 2, 0).numpy()
    arr = np.clip((arr + 1) / 2, 0, 1) * 255
    # astype truncates (no rounding), as the original helper did.
    arr = arr.astype('uint8')
    if grayscale or arr.shape[2] == 1:
        return Image.fromarray(arr.squeeze(axis=2), 'L')
    return Image.fromarray(arr)

In [None]:
# Render the first generated sample as an 8-bit grayscale PIL image.
output_image = tensor2im(result_image[0], grayscale=True)
output_image

In [None]:
# Distinct pixel intensities present in the rendered output.
np.unique(output_image)