# **Context Encoders**


## Load libraries and models

In [None]:
import numpy as np
from utils.load_dataset import get_all_datasets
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import matplotlib.pyplot as plt
from model import ContextEncoder, Discriminator
from classifier import create_wildfire_classifier
from utils.utils_gan import get_mask, apply_mask, train_one_epoch, load_checkpoint, plot_comparison

In [None]:
def _as_numpy(x):
    """Return *x* as a NumPy array, detaching and copying torch tensors to the CPU first."""
    if hasattr(x, 'detach'):
        # torch tensors may live on the GPU and/or require grad; both block np conversion.
        return x.detach().cpu().numpy()
    return np.asarray(x)


def extract_masked_zones(true, pred, mask, patch_shape=(50, 50), n_channels=3):
    """Cut the hole region out of a ground-truth and a predicted image.

    Args:
        true: Ground-truth batch indexed as ``[batch, channel, H, W]``; only
            item 0 is used. A torch tensor (any device) or array-like.
        pred: Predicted batch with the same layout as ``true``.
        mask: Mask indexed as ``[batch, 1, H, W]``. Pixels where ``1 - mask`` is
            nonzero form the hole; only ``mask[0, 0]`` is read.
        patch_shape: ``(height, width)`` of the hole. The hole pixels, taken in
            row-major order, are reshaped to this shape.
        n_channels: Number of leading channels to extract.

    Returns:
        ``(zone_true, zone_pred)``: two float64 arrays of shape
        ``(height, width, n_channels)``.

    Raises:
        ValueError: If the number of hole pixels differs from
            ``height * width``.
    """
    true_np = _as_numpy(true)
    pred_np = _as_numpy(pred)
    mask_np = _as_numpy(mask)

    # Flat (row-major) indices of the hole pixels. They are computed once and
    # shared by every channel of both images.
    hole_idx = np.flatnonzero(1 - mask_np[0, 0])
    height, width = patch_shape
    if hole_idx.size != height * width:
        raise ValueError(
            f'masked region has {hole_idx.size} pixels, '
            f'cannot reshape into patch_shape={tuple(patch_shape)}'
        )

    def _zone(img):
        # (C, H*W) -> keep hole pixels -> (C, h, w) -> channels-last for plotting.
        flat = img[0, :n_channels].reshape(n_channels, -1)
        patch = flat[:, hole_idx].reshape(n_channels, height, width)
        return patch.transpose(1, 2, 0).astype(np.float64)

    return _zone(true_np), _zone(pred_np)


def extract_masked(true, pred, mask, title, patch_shape=(50, 50), n_channels=3):
    """Plot the hole region of the ground truth next to the same region of the prediction.

    See ``extract_masked_zones`` for the expected layouts of ``true``, ``pred``
    and ``mask``. ``title`` is forwarded to ``plot_comparison``.
    """
    zone_true, zone_pred = extract_masked_zones(true, pred, mask, patch_shape, n_channels)
    plot_comparison(zone_true, zone_pred, title)

## Load data

In [None]:
dataset_path = Path('../data')
# One directory per split; the 'train' directory is what gets passed as pretrain_path.
pretrain_path, val_path, test_path = (
    dataset_path / split_dir for split_dir in ('train', 'valid', 'test')
)

# Every split uses the same (separate) transform pipeline: conversion to a tensor only.
data_transforms = {
    split_name: transforms.Compose([transforms.ToTensor()])
    for split_name in ('pretrain', 'valid', 'test')
}
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

pretrain_dataset, train_dataset, val_dataset, test_dataset = get_all_datasets(
    pretrain_path=pretrain_path,
    val_path=val_path,
    test_path=test_path,
    transforms_dict=data_transforms,
)

## Load model

In [None]:
pretrain_dataloader = DataLoader(pretrain_dataset, batch_size=16)
# NOTE(review): same value as the `device` set in the data-loading cell; redundant here.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build both networks, then let load_checkpoint fill them from the saved checkpoint.
path_check = '../checkpoints/context_encoder/context_encoder.pt'
model, discriminator = ContextEncoder(), Discriminator()
load_checkpoint(path_check, model, discriminator)

## Visualize ContextEncoder performance

In [None]:
ds = train_dataset
idx = np.random.randint(0, len(ds))
true = ds[idx][0][None]  # first element of the sample, with a leading batch dim added
mask = get_mask(true.shape)
input_masked, output_masked_gt = apply_mask(true, mask)

# Inference only: use eval mode (affects any dropout / batch-norm layers) and skip
# autograd bookkeeping. Restore the previous mode so later cells are unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    pred = model(input_masked)
model.train(was_training)

# Keep the real pixels where mask == 1 and use the prediction only inside the hole.
reconstructed = pred * (1 - mask) + true * mask

plot_comparison(
    input_masked[0].permute(1, 2, 0).detach().cpu().numpy(),
    reconstructed[0].permute(1, 2, 0).detach().cpu().numpy(),
    'Input masked vs Reconstructed Image',
)
plt.show()

extract_masked(true, reconstructed, mask, 'Masked zone vs Reconstructed zone')