# Latent to Latent

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline  
import os
# Switch to the repository root; presumably needed so the project-local imports
# (training., models., utils.) and the relative "artifacts/..." paths below resolve.
# NOTE(review): machine-specific absolute path — adjust for your environment.
os.chdir('/home/extra/micheal/pixel2style2pixel')

In [None]:
from argparse import Namespace
from tqdm import tqdm
import numpy as np
import torch
from training.coach import Coach
from models.regressor import Regressor
from models.latent2latent import Latent2Latent, LightningLatent2Latent
from utils.latent_utils import train_imgs_batch
from utils.regressor_utils import attribute_label_from_segmentation, get_train_loader_from_checkpoint
from utils.latent_utils import modify_attribute, get_latent

## 1. Compute a correlation matrix

Let $a'$ be the attribute vector in which a single attribute has been changed. To correct infeasible attribute combinations, we first compute a correlation matrix over all attributes in the training dataset. We then build a corrected vector $a'_c$ by multiplying $a'$ with the elements of the corresponding row of the correlation matrix whose values exceed a threshold.

### 1.1 Acquire attributes from training dataset

In [None]:
help(attribute_label_from_segmentation)
help(get_train_loader_from_checkpoint)
# Build the training DataLoader from the pSp checkpoint. The next cell iterates over
# `train_loader`, so it must be defined here or that cell fails with a NameError.
train_loader = get_train_loader_from_checkpoint('/home/extra/micheal/pixel2style2pixel/experiments/ioct_seg2bscan2/checkpoints/best_model.pt')

In [None]:
# Collect one attribute vector per training sample.
all_attributes = []
for seg, bscan in tqdm(train_loader):
    # `seg` is a batch of segmentations; compute the attributes sample by sample.
    all_attributes.extend(attribute_label_from_segmentation(s) for s in seg)
# Force a floating-point dtype: the in-place divisions below raise a casting error
# on an integer array (numpy will not write float64 results into int64 storage).
all_attributes = np.array(all_attributes, dtype=np.float64)
print("attributes shape", all_attributes.shape)
# Scale selected columns into a unit range: columns 1, 3, 6, 8 by 1024 and
# columns 2, 4, 7, 9 by 512 (presumably pixel coordinates divided by the
# image width/height — TODO confirm against attribute_label_from_segmentation).
all_attributes[:, [1, 3, 6, 8]] /= 1024
all_attributes[:, [2, 4, 7, 9]] /= 512

Save for future use

In [None]:
# Persist the normalized training attributes. Writing through an open file handle
# keeps the exact file name (np.save on a path would append ".npy").
attributes_path = "artifacts/objects/attributes_train.np"
with open(attributes_path, mode="wb") as out_file:
    np.save(out_file, all_attributes)

### 1.2 Compute and plot correlation matrix

In [None]:
# np.corrcoef treats rows as variables by default; rowvar=False makes the columns
# (the attributes) the variables. This is equivalent to correlating the transpose,
# but it does not rebind `all_attributes`. A re-run of this cell therefore cannot
# flip the array back and produce a sample-by-sample matrix.
print("all attributes transposed shape", all_attributes.T.shape)
R1 = np.corrcoef(all_attributes, rowvar=False)
print("correlation matrix shape", R1.shape)

In [None]:
import matplotlib.pyplot as plt

# Heat-map of the attribute-by-attribute correlation matrix with a labelled colorbar.
TICK_FONT_SIZE, TITLE_FONT_SIZE = 14, 16
plt.matshow(R1)
plt.colorbar().ax.tick_params(labelsize=TICK_FONT_SIZE)
plt.title('Correlation Matrix', fontsize=TITLE_FONT_SIZE);

Save correlation matrix

In [None]:
# Persist the correlation matrix under its exact file name (written via a handle,
# so np.save does not append ".npy").
corrmat_path = "artifacts/objects/corrmat10.np"
with open(corrmat_path, mode="wb") as out_file:
    np.save(out_file, R1)

### 1.3 Load directly if already saved

In [None]:
import numpy as np

# Reload a previously saved correlation matrix instead of recomputing it.
corrmat_path = "artifacts/objects/corrmat10.np"
with open(corrmat_path, mode="rb") as in_file:
    R1 = np.load(in_file)

## 2. Define parameter correction method

For now, I only experiment with the horizontal location of the instrument.

### 2.0 Simply import everything
Since these functions have been moved into modules, we simply import them.

In [None]:
# Show the documentation of the helpers imported from utils.
for helper in (
    attribute_label_from_segmentation,
    get_train_loader_from_checkpoint,
    modify_attribute,
    get_latent,
):
    help(helper)

### 2.1 Define functions for modifying attributes

In [None]:
from utils.latent_utils import get_latent, modify_attribute

# Show the documentation of the attribute-modification helpers.
for fn in (modify_attribute, get_latent):
    help(fn)

## 3. Train

### 3.1 Define functions for training a batch

In [None]:
# Show the documentation of the mapper model and the per-batch loss helper.
for obj in (Latent2Latent, train_imgs_batch):
    help(obj)

### 3.2 Prepare

Test:

- Load regressor:

In [None]:
import torch

# Load the pretrained attribute regressor. map_location="cpu" lets the weights load
# even if they were saved from a GPU that is not available here; the regressor is
# moved to the working device later with `.to(device)`.
regressor = Regressor()
regressor.load_state_dict(torch.load("artifacts/weights/regressor.pt", map_location="cpu"))

- Load stylegan model

In [None]:
model_path = '/home/extra/micheal/pixel2style2pixel/experiments/ioct_seg2bscan2/checkpoints/best_model.pt'

# Rebuild the training options stored in the checkpoint, override a few of them,
# and construct the Coach (which provides net, train_dataset and train_dataloader).
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
optss = Namespace(**{
    **opts,
    'batch_size': 16,
    'stylegan_weights': model_path,
    'load_partial_weights': True,
})
coach = Coach(optss)

device = torch.device(coach.opts.device)

- define latent model

In [None]:
# Fresh latent-to-latent mapper, placed on the working device.
latent2latent = Latent2Latent()
latent2latent = latent2latent.to(device)

Sanity check of the latent extraction for latent-to-latent (snippet, not executed):
```
seg, bscan = coach.train_dataset[0]
print("seg shape", seg.shape)
pred, latent, codes = get_latent(coach.net, seg, device)
print(f"pred shape: {list(pred.size())}, latent shape: {list(latent.size())}, codes shape {list(codes.shape)}")
```

- Validity check

In [None]:
# Sanity check: push one training batch through the loss without building a graph.
batch = next(iter(coach.train_dataloader))
segs, bscans = batch

# Reuse the device chosen from the Coach options instead of hard-coding 'cuda'.
# latent2latent already lives on `device`; overriding it here could put the inputs
# and models on different devices (e.g. when opts.device is 'cuda:1' or 'cpu').
segs = segs.to(device).float()
regressor = regressor.to(device)

with torch.no_grad():
    loss = train_imgs_batch(segs, coach.net, regressor, latent2latent, device, R1)
loss

### 3.3 Training loop

In [None]:
from torch import optim

# Only the latent2latent parameters are handed to the optimizer.
optimizer = optim.AdamW(params=latent2latent.parameters(), lr=0.01)

In [None]:
# Freeze the pSp network and the regressor so only latent2latent is trained.
# requires_grad_(False) stops gradients for their parameters. eval() additionally
# keeps any BatchNorm running statistics from being updated and disables any
# dropout while these models are evaluated inside the training loss.
for frozen in (coach.net, regressor):
    frozen.requires_grad_(False)
    frozen.eval()

In [None]:
# Train latent2latent for 10 epochs, saving its weights after every epoch.
for epoch in tqdm(range(10)):
    pbar = tqdm(coach.train_dataloader, position=0, leave=True)
    for segs, bscans in pbar:
        segs = segs.to(device).float()
        # Clear the previous step's gradients; otherwise backward() keeps adding
        # into .grad and every update uses the running sum of all past gradients.
        optimizer.zero_grad()
        loss = train_imgs_batch(segs, coach.net, regressor, latent2latent, device, R1)
        pbar.set_description("Loss {:.5f}".format(loss.detach().cpu().item()))
        loss.backward()
        optimizer.step()
    torch.save(latent2latent.state_dict(), "artifacts/latent2latent.pt")

In [None]:
# Restore the latent2latent weights saved by the training loop.
latent2latent_state = torch.load("artifacts/latent2latent.pt")
latent2latent.load_state_dict(latent2latent_state)

#### Train with PyTorch Lightning

In [None]:
# `pl` is not imported anywhere else in this notebook; import it so the cell runs.
import pytorch_lightning as pl

# Wrap the pSp net, the regressor and the correlation matrix in the Lightning module,
# initialise it from the checkpoint saved above (load_latent_state presumably restores
# the latent2latent weights — confirm), and train.
pl_latent2latent = LightningLatent2Latent(coach.net, regressor, R1).to(device)
pl_latent2latent.load_latent_state("artifacts/latent2latent.pt")
# NOTE(review): `gpus=` and `weights_save_path=` are Lightning 1.x Trainer arguments that
# later releases deprecated/removed — pin the version or migrate to accelerator/devices.
trainer = pl.Trainer(gpus=1, max_epochs=10, weights_save_path="artifacts/weights/latent2latent_pl")
trainer.fit(model=pl_latent2latent, train_dataloaders=coach.train_dataloader)

## 4. Inference

### 4.1 Load a sample

In [None]:
import matplotlib.pyplot as plt
import numpy as np

sample_idx = 69
seg, bscan = coach.train_dataset[sample_idx]
# Collapse axis 0 with argmax into a single label map for display
# (assumes seg holds one channel per class — confirm).
seg_im = np.argmax(seg, axis=0)
plt.imshow(seg_im)

### 4.2 Change an attribute & correct the modification

Get style vector

In [None]:
# Add a batch dimension and encode the sample into its latent representation.
# Use the notebook-wide `device` rather than a hard-coded .cuda(), so the input
# lands on the same device the rest of the notebook uses.
segs = seg.unsqueeze(0).float().to(device)
with torch.no_grad():
    _, w_latents, w_codes = get_latent(coach.net, segs, device)

In [None]:
# Compute the normalized attributes of the sample and ask modify_attribute
# (which also receives R1 for the correlation-based correction) for the edited
# version. change_nr=1, change=1 presumably means "attribute 1, +1" — confirm.
attribute = attribute_label_from_segmentation(seg, normalize=True)
modified_attribute, actual_change = modify_attribute(attribute, R1, change_nr=1, change=1)
# Convert both vectors to tensors with a leading batch dimension and take their delta.
attribute, modified_attribute = (
    torch.Tensor(values).unsqueeze(0) for values in (attribute, modified_attribute)
)
delta_attributes = modified_attribute - attribute
print(actual_change)

### 4.3 Generate a new latent

In [None]:
# Move both inputs to the working device and predict the edited latent.
w_latents = w_latents.to(device)
delta_attributes = delta_attributes.to(device)
with torch.no_grad():
    w_n = latent2latent(w_latents, delta_attributes)
# Blend the prediction with the original latent: 70% original, 30% predicted.
keep_weight, edit_weight = 0.7, 0.3
w_n = keep_weight * w_latents + edit_weight * w_n

### 4.4 Generate a new sample

In [None]:
# Decode the *edited* latent `w_n` computed above. Decoding `w_latents` would only
# reconstruct the unmodified sample, so the attribute change would never be shown.
# input_code=True presumably makes the net treat its input as a latent code.
with torch.no_grad():
    generated_images = coach.net(w_n, input_code=True).detach().cpu().numpy()
plt.imshow(generated_images[0][0])

In [None]:
# Reconstruct the unmodified sample straight from its segmentation, for comparison.
with torch.no_grad():
    reconstruction_tensor = coach.net(segs)
original_reconstruction = reconstruction_tensor.detach().cpu().numpy()
plt.imshow(original_reconstruction[0][0])