<a href="https://colab.research.google.com/github/eduardojdiniz/Buzznauts/blob/master/scripts/wandb_VAE_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VAE Model: Predicting fMRI responses from Algonauts2021 dataset

**Goal:** Prepare submission for Algonauts 2021 challenge using the VAE model.

### Setup

In [1]:
# Install dependencies and Download Buzznauts
%%capture

!pip install duecredit --quiet
!pip install torchinfo --quiet
!pip install nilearn --quiet
!pip install decord --quiet
!pip install git+https://github.com/eduardojdiniz/Buzznauts --quiet
!pip install wandb --quiet

In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [3]:
# Set videos and annotation file path
import os
import os.path as op
from pathlib import Path
import wandb

drive_root = "/content/drive/MyDrive/Buzznauts"

# Data paths
fmri_dir = op.join(drive_root, "data", "fmri")
stimuli = op.join(drive_root, "data", "stimuli") 
videos_dir = op.join(stimuli, "videos")
frames_dir = op.join(stimuli, "frames")
annotation_file = op.join(frames_dir, 'annotations.txt')
pretrained_dir = op.join(drive_root, "data", "pretrained")
pretrained_vaegan = op.join(pretrained_dir, "vaegan_enc_weights.pickle")

# Visualizations path
viz_dir = "/content/visualizations"
viz_vae_dir = op.join(viz_dir, "vae")

# Model path
models_dir = "/content/models"
model_vae_dir = op.join(models_dir, "vae")

# Results paths
results_dir = "/content/results/vae"

In [4]:
# Import interactive tools
from tqdm.notebook import tqdm, trange

In [5]:
# Import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision
from torchvision import datasets, transforms

In [6]:
from Buzznauts.utils import set_seed, set_device, seed_worker, set_generator
from Buzznauts.data.utils import plot_video_frames
from Buzznauts.data.videodataframe import VideoFrameDataset, ImglistToTensor, FrameDataset

## Model Architecture

In [7]:
class ConvVarAutoEncoder(nn.Module):
    def __init__(self, K, data_shape=(3, 128, 128), num_filters=[192, 256, 384, 512, 768], filter_size=3):
        super(ConvVarAutoEncoder, self).__init__()
        ## 5 Conv Layers
        filter_reduction = 5 * (filter_size // 2)

        self.shape_after_conv = calc_output_size(data_shape, filter_size, num_filters)

        self.flat_shape = self.shape_after_conv[0] * self.shape_after_conv[1] * self.shape_after_conv[2]

        # Double for each additional layer of Conv
        flat_size_after_conv = self.shape_after_conv[0] * self.shape_after_conv[1] * self.shape_after_conv[2]

        # ENCODER
        self.q_bias = BiasLayer(data_shape)
        self.q_conv_1 = nn.Conv2d(data_shape[0], num_filters[0], filter_size)
        self.q_conv_2 = nn.Conv2d(num_filters[0], num_filters[1], filter_size)
        self.q_conv_3 = nn.Conv2d(num_filters[1], num_filters[2], filter_size)
        self.q_conv_4 = nn.Conv2d(num_filters[2], num_filters[3], filter_size)
        self.q_conv_5 = nn.Conv2d(num_filters[3], num_filters[4], filter_size)
        self.q_flatten = nn.Flatten()
        self.q_fc_phi = nn.Linear(self.flat_shape, K+1)

        # DECODER
        self.p_fc_upsample = nn.Linear(K, self.flat_shape)
        self.p_unflatten = nn.Unflatten(-1, self.shape_after_conv)
        self.p_deconv_1 = nn.ConvTranspose2d(num_filters[4], num_filters[3], filter_size)
        self.p_deconv_2 = nn.ConvTranspose2d(num_filters[3], num_filters[2], filter_size)
        self.p_deconv_3 = nn.ConvTranspose2d(num_filters[2], num_filters[1], filter_size)
        self.p_deconv_4 = nn.ConvTranspose2d(num_filters[1], num_filters[0], filter_size)
        self.p_deconv_5 = nn.ConvTranspose2d(num_filters[0], data_shape[0], filter_size)

        self.p_bias = BiasLayer(data_shape)

        # Define a special extra parameter to learn scalar sig_x for all pixels
        self.log_sig_x = nn.Parameter(torch.zeros(()))


    def infer(self, x):
        """Map (batch of) x to (batch of) phi which can then be passed to
        rsample to get z
        """
        s = self.q_bias(x)
        s = F.elu(self.q_conv_1(s))
        s = F.elu(self.q_conv_2(s))
        s = F.elu(self.q_conv_3(s))
        s = F.elu(self.q_conv_4(s))
        s = F.elu(self.q_conv_5(s))
        flat_s = s.view(s.size()[0], -1)
        phi = self.q_fc_phi(flat_s)
        return phi


    def generate(self, zs):
        """Map [b,n,k] sized samples of z to [b,n,p] sized images
        """
        # Note that for the purposes of passing through the generator, we need
        # to reshape zs to be size [b*n,k]
        b, n, k = zs.size()
        s = zs.view(b*n, -1)
        s = F.elu(self.p_fc_upsample(s)).view((b*n,) + self.shape_after_conv)
        s = F.elu(self.p_deconv_1(s))
        s = F.elu(self.p_deconv_2(s))
        s = F.elu(self.p_deconv_3(s))
        s = F.elu(self.p_deconv_4(s))
        s = self.p_deconv_5(s)
        s = self.p_bias(s)
        mu_xs = s.view(b, n, -1)
        return mu_xs

    
    def decode(self, zs):
        # Included for compatability with conv-AE code
        return self.generate(zs.unsqueeze(0))

    
    def forward(self, x):
        # VAE.forward() is not used for training, but we'll treat it like a
        # classic autoencoder by taking a single sample of z ~ q
        phi = self.infer(x)
        zs = rsample(phi, 1)
        return self.generate(zs).view(x.size())
    

    def elbo(self, x, n=1, epsilon = 0.1):
        """Run input end to end through the VAE and compute the ELBO using n
        samples of z
        """
        phi = self.infer(x)
        zs = rsample(phi, n)
        mu_xs = self.generate(zs)
        return log_p_x(x, mu_xs, self.log_sig_x.exp()) - kl_q_p(zs, phi, epsilon), mu_xs.data

    
    def load_my_state_dict(self, state_dict):
        curr_state=self.state_dict()

        for name, param in state_dict.items():
            if name not in curr_state:
                continue
            if isinstance(param, torch.Tensor):
                param = param.data
            curr_state[name].copy_(param)
    
    
class BiasLayer(nn.Module):
    def __init__(self, shape):
        super(BiasLayer, self).__init__()
        init_bias = torch.zeros(shape)
        self.bias = nn.Parameter(init_bias, requires_grad=True)

    def forward(self, x):
        return x + self.bias
    
    
def calc_output_size(input_size, kernel_size, kchannels, padding=0, stride=1):
    output_size = input_size
    for kc in kchannels:
        output_height = (output_size[1] + padding + padding - kernel_size) / (stride) + 1
        output_width = (output_size[2] + padding + padding - kernel_size) / (stride) + 1

        output_size = [kc, int(output_height), int(output_width)]

    return tuple(output_size)

### ELBO loss helper functions

In [8]:
def kl_q_p(zs, phi, epsilon=0.1):
    """Given [b,n,k] samples of z drawn from q, compute estimate of KL(q||p).
    phi must be size [b,k+1]

    This uses mu_p = 0 and sigma_p = 1, which simplifies the log(p(zs)) term to
    just -1/2*(zs**2)
    """
    b, n, k = zs.size()
    mu_q, log_sig_q = phi[:,:-1], phi[:,-1]
    log_p = -0.5*(zs**2)
   
    log_q = -0.5*(zs - mu_q.view(b,1,k))**2 / (log_sig_q.exp().view(b,1,1)**2 + epsilon) - log_sig_q.view(b,1,-1)
    # Size of log_q and log_p is [b,n,k]. Sum along [k] but mean along [b,n]
    return (log_q - log_p).sum(dim=2).mean(dim=(0,1))


def log_p_x(x, mu_xs, sig_x):
    """Given [batch, ...] input x and [batch, n, ...] reconstructions, compute
    pixel-wise log Gaussian probability

    Sum over pixel dimensions, but mean over batch and samples.
    """
    b, n = mu_xs.size()[:2]
    # Flatten out pixels and add a singleton dimension [1] so that x will be
    # implicitly expanded when combined with mu_xs
    x = x.reshape(b, 1, -1)
    _, _, p = x.size()
    squared_error = (x - mu_xs.view(b, n, -1))**2 / (2*sig_x**2)

    # Size of squared_error is [b,n,p]. log prob is by definition sum over [p].
    # Expected value requires mean over [n]. Handling different size batches
    # requires mean over [b].
    return -(squared_error + torch.log(sig_x)).sum(dim=2).mean(dim=(0,1))


def rsample(phi, n_samples):
    """Sample z ~ q(z;phi)
    Ouput z is size [b,n_samples,K] given phi with shape [b,K+1]. The first K
    entries of each row of phi are the mean of q, and phi[:,-1] is the log
    standard deviation
    """
    b, kplus1 = phi.size()
    k = kplus1-1
    mu, sig = phi[:, :-1], phi[:,-1].exp()
    eps = torch.randn(b, n_samples, k, device=phi.device)
    return eps*sig.view(b,1,1) + mu.view(b,1,k)

### Model Weights

In [9]:
import pickle

In [10]:
def load_vaegan_weights(model, pretrained_path):
    # load pretrained weights
    pretrained_fn = open(pretrained_path,'rb')
    pretrained = pickle.load(pretrained_fn)

    # have a look what's in the pretrained file
    old_keynames=[]
    for key, value in pretrained.items():
        old_keynames.append(key) 

    # get the keynames of our model
    curr_state=model.state_dict()
    new_keynames=[]
    for key, value in curr_state.items():
        if key.startswith('q_conv'):
            new_keynames.append(key)

    # change the names of the pretrained model to match our model
    for i in range(len(old_keynames)):
        pretrained[new_keynames[i]] = pretrained[old_keynames[i]]
        del pretrained[old_keynames[i]]

    # change size & make the weights a torch
    # In TF, Conv2d filter shape is [filter_height, filter_width, in_channels, out_channels],
    # while in Pytorch is (out_channels, in_channels, kernel_size[0], kernel_size[1])
    # So we need to permute [3,2,0,1]
    for key, value in pretrained.items():
        if len(value.shape)==4:
            new_val=torch.tensor(value)
            new_val=new_val.permute(3,2,0,1)
        else: 
            new_val=torch.tensor(value)
    
        pretrained[key] = new_val
    
    return pretrained

In [11]:
def reset_weights(model):
    """Try resetting model weights to avoid weight leakage.
    
    Parameters
    ----------
    model: torch.nn.Module
    """
    for layer in model.children():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

### Main

In [12]:
from sklearn.model_selection import KFold

In [13]:
# Configuration

# Set seed to the random generators to ensure reproducibility
seed = set_seed()

# Set computational device (cuda if GPU is available, else cpu)
device = set_device()
      
# Number of folds for cross-validation 
k_folds = 5

# Define the K-fold Cross Validator
kfold = KFold(n_splits=k_folds, shuffle=True)

# Number of epochs
num_epochs = 10 

# Batch size
batch_size = 4 

# Size of the VAE's latent space  
K_VAE = 128 

#---------------
# Create Dataset
#---------------

# Number of splits in each video
num_segments = 5

# Number of frames per split
frames_per_segment = 6

# Total number of training frames
total_frames = num_segments * frames_per_segment

# Frame size
frame_size = 32 
width = frame_size
height = frame_size

# Num of channels
num_channels = 3

# Data shape
data_shape = (num_channels, frame_size, frame_size)

# Input size
input_size = (batch_size, num_channels, frame_size, frame_size)

# Tensorize convert PIL images to tensors and resize each frame to frame_size
tensorize = transforms.Compose([
    ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
    transforms.Resize(frame_size), # image batch, resize smaller edge to 128
])

# Preprocess center crop to 100x128, normalize and apply random affine
# and horizontal flips to each frame
preprocess = transforms.Compose([
    transforms.CenterCrop((frame_size, frame_size)), # image batch, center crop to square 100x128
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.78125, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5)
])

# Videoframe dataset: each sample is of size (FRAMES X CHANNELS X HEIGHT X WIDTH)
videoframe_dataset = VideoFrameDataset(
    root_path=frames_dir,
    annotationfile_path=annotation_file,
    num_segments=num_segments,
    frames_per_segment=frames_per_segment,
    imagefile_template='img_{:05d}.jpg',
    transform=tensorize,
    random_shift=False,
    test_mode=False
)

# Frame dataset: each sample is of size (CHANNELS X HEIGHT X WIDTH)
frame_dataset = FrameDataset(
    videoframedataset=videoframe_dataset,
    transform=preprocess
)

Random seed 724233587 has been set.
GPU is not enabled.


In [14]:
# Model summary
from torchinfo import summary

convVAE = ConvVarAutoEncoder(data_shape=data_shape, K=K_VAE)
wandb.init()
wandb.watch(convVAE)
summary(convVAE, input_size=input_size)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Layer (type:depth-idx)                   Output Shape              Param #
ConvVarAutoEncoder                       --                        --
├─BiasLayer: 1-1                         [4, 3, 32, 32]            3,072
├─Conv2d: 1-2                            [4, 192, 30, 30]          5,376
├─Conv2d: 1-3                            [4, 256, 28, 28]          442,624
├─Conv2d: 1-4                            [4, 384, 26, 26]          885,120
├─Conv2d: 1-5                            [4, 512, 24, 24]          1,769,984
├─Conv2d: 1-6                            [4, 768, 22, 22]          3,539,712
├─Linear: 1-7                            [4, 129]                  47,950,977
├─Linear: 1-8                            [4, 371712]               47,950,848
├─ConvTranspose2d: 1-9                   [4, 512, 24, 24]          3,539,456
├─ConvTranspose2d: 1-10                  [4, 384, 26, 26]          1,769,856
├─ConvTranspose2d: 1-11                  [4, 256, 28, 28]          884,992
├─ConvTranspose2d: 1

#### Train!

In [None]:
 torch.autograd.set_detect_anomaly(True)

# Save loss values during training for each fold
loss_train = {f'Fold_{i}': [] for i in range(1, k_folds+1)}
# Save loss during validation for each fold
loss_val = {f'Fold_{i}': [] for i in range(1, k_folds+1)}
# Save overall loss during validation for each fold
loss_val_overall = {f'Fold_{i}': None for i in range(1, k_folds+1)}

# K-fold Cross Validation model evaluation
for fold, (train_idx, val_idx) in enumerate(kfold.split(frame_dataset)):
    print(f'FOLD {fold+1}')
    print('-------------------------')
    
    # Sample elements randomly from a given list of idx, no replacement
    train_subsampler = SubsetRandomSampler(train_idx)
    val_subsampler = SubsetRandomSampler(val_idx)
    
    # Define data loaders for training and testing data in this fold
    train_loader = DataLoader(
        dataset=frame_dataset,
        batch_size=batch_size,
        sampler=train_subsampler,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=set_generator())
    
    val_loader = DataLoader(
        dataset=frame_dataset,
        batch_size=batch_size,
        sampler=val_subsampler,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=set_generator())
    
    # Instantiate network
    convVAE = ConvVarAutoEncoder(data_shape=data_shape, K=K_VAE)
    convVAE.apply(reset_weights)

    # Load Pretrained weights
    # pretrained = load_vaegan_weights(convVAE, pretrained_vaegan)
    # convVAE.load_my_state_dict(pretrained)
    
    # # Freezing layers
    # freeze_idx = [2, 3, 4, 5]
    # for idx, param in enumerate(convVAE.parameters()): 
    #     if idx in freeze_idx: param.requires_grad = False

    # Initialize optimizer
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, convVAE.parameters()),
                                 lr=3e-4, weight_decay=0)
    
    convVAE.to(device)
    convVAE.train()
    
    # Run the training loop for defined number of epochs
    for epoch in trange(num_epochs, desc='Epochs'):
        
        # Set current loss value
        current_loss = 0.0
        
        # Iterate over the DataLoader for training data
        for i, (frame, label) in enumerate(tqdm(train_loader, 
                                                total=len(train_loader) // batch_size,
                                                desc='Batches', leave=False)):
            frame = frame.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Compute loss
            loss, muxs = convVAE.elbo(frame)
            loss = -loss
            # Perform backward pass
            loss.backward()

            # Perform optimization
            optimizer.step()

            # Saving loss
            loss_train[f'Fold_{fold+1}'].append(-loss.item())

            # Print statistics
            current_loss += loss.item()
            if i % 100 == 99:
                print('Loss after mini-batch %5d: %.3f' %
                      (i + 1, current_loss / 100))
                current_loss = 0.0
            
            example_frames = [wandb.Image(f, caption="Input {idx}".format(idx=idx)) for idx, f in enumerate(frame)]
            example_gens = [wandb.Image(g.reshape(3, 32, 32), caption="Generated {idx}".format(idx=idx)) for idx, g in enumerate(muxs)]

            wandb.log({"Loss: " : loss.item(),
                       "Input: ": example_frames,
                       "Generated: ": example_gens})
        
    # Evaluation for this fold
    convVAE.eval()
    correct, total = 0, 0
    with torch.no_grad():
        # Iterate over the DataLoader for validation data
        for i, (frame, label) in enumerate(tqdm(val_loader, 
                                                total=len(val_loader) // batch_size,
                                                desc='Batches', leave=False)):
            # Compute loss
            loss = -convVAE.elbo(frame)

            # Saving loss
            loss_val[f'Fold_{fold+1}'].append(-loss.item())

        # Print overall fold loss 
        loss_val_overall[f'Fold_{fold+1}'] = sum(loss_val[f'Fold_{fold+1}'])
        print('Total loss for fold %d: %d %%' % (fold, results[f'Fold_{fold+1}']))
        print('--------------------------------')
    
# Print fold results
print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
print('----------------------------------------------------')
overall_sum = 0.0
for key, value in loss_val_overall.items():
    print(f'Fold {key+1}: {value} %')
    overall_sum += value
print(f'Average: {overall_sum/len(loss_val_overall.items())} %')

FOLD 1
-------------------------


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/1653 [00:00<?, ?it/s]

In [68]:
elbo_vals

NameError: ignored

In [68]:
3*32*32


3072

In [52]:
for f in frame: print(f

torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])


In [72]:
mu_xs[0].reshape(3, 32, 32)

torch.Size([3, 32, 32])

In [81]:
mu_xs.data

tensor([[[nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan]]])