In [1]:
import os

import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from lateral_connections import LateralModel, VggModel
from lateral_connections import VggWithLCL
from lateral_connections import MNISTCDataset
from lateral_connections.loaders import get_loaders, load_mnistc
from lateral_connections.character_models import SmallVggWithLCL, VGGReconstructionLCL
from lateral_connections.torch_utils import *
from lateral_connections.model_factory import *

import datetime

# Load three MNIST-C variants by directory name (same call order as before).
dataset_identity, dataset_line, dataset_gaussian_noise = [
    load_mnistc(dirname=variant)
    for variant in ('identity', 'line', 'gaussian_noise')
]

In [2]:
# Build the model registered under the key 'vgg19r_lcl' (helper comes from the
# star-imported lateral_connections modules). The training cell below relies on its
# .features.lcl.K, .device, .loss_fn and .optimizer attributes.
model = load_model_by_key('vgg19r_lcl')

In [3]:
# get_loaders returns four loaders; only the first (training) one is kept here, for the
# 'line' variant. NOTE(review): the first argument (10) is presumably the batch size —
# confirm against get_loaders.
train_loader, _, _, _ = get_loaders(10, 'line')

### Train a new VGG19R+LCL@5 model
At each iteration, the kernels are saved as images to show how they change over time.

#### This time, however: start with zero-initialized kernels!

In [None]:
# Train the model for one pass over train_loader while recording how the lateral
# (LCL) kernels evolve. Before every optimizer step the top-left
# num_kernels x num_kernels corner of model.features.lcl.K is rendered as a grid of
# images ("<iteration>.png"), and a running "kernel change" curve is written to
# "_change.png".
# NOTE(review): the markdown heading announces zero-initialized kernels, but this cell
# does not zero model.features.lcl.K itself — confirm load_model_by_key does.
num_kernels = 10   # grid size: kernels K[:num_kernels, :num_kernels] are plotted
plot_scale = 4     # figure inches per subplot
plot_every = 1     # write a kernel-grid frame every N iterations (1 = every iteration)
evolution_dir = 'models/vgg_reconstructed_lcl/lcl_kernel_evolution/'

# savefig does not create missing directories.
os.makedirs(evolution_dir, exist_ok=True)

counter = 0

change_data = []
prev_data = None
k_data = None

for images, labels in tqdm(train_loader, total=len(train_loader), desc='Training'):
    # .copy() matters: for a CPU model .cpu()/.detach()/.numpy() all share memory with
    # the parameter, so without it the snapshot would be updated in place by
    # optimizer.step() and every recorded change would collapse to zero.
    snapshot = model.features.lcl.K.detach().cpu().numpy()[:num_kernels, :num_kernels, ...].copy()
    # First iteration compares against an all-zero buffer of matching shape.
    prev_data = np.zeros_like(snapshot) if k_data is None else k_data
    k_data = snapshot
    # Sum of absolute differences: a signed sum lets positive and negative updates
    # cancel and can report ~0 even when the kernels move a lot.
    change_data.append(np.abs(k_data - prev_data).sum())

    # Kernel Plot: subplot (row, col) shows kernel K[col, row], titled "col --> row".
    if counter % plot_every == 0:
        fig, axs = plt.subplots(num_kernels, num_kernels, squeeze=False,
                                figsize=(plot_scale * num_kernels, plot_scale * num_kernels))
        for row in range(num_kernels):
            for col in range(num_kernels):
                axs[row, col].imshow(k_data[col, row, ...])
                axs[row, col].set_title(str(col) + ' --> ' + str(row))
        fig.suptitle('Iteration ' + str(counter), size=24)
        fig.tight_layout()
        fig.savefig(evolution_dir + str(counter) + '.png')
        # Close explicitly: one large figure per iteration would otherwise pile up.
        plt.close(fig)

    # Change Plot (the first entry, measured against the all-zero start, is skipped).
    if len(change_data) > 1:
        change_fig, change_ax = plt.subplots()
        change_ax.plot(change_data[1:])
        change_fig.savefig(evolution_dir + '_change.png')
        plt.close(change_fig)

    images = images.to(model.device)
    labels = labels.to(model.device)

    outputs = model(images)
    loss = model.loss_fn(outputs, labels)

    model.optimizer.zero_grad()
    loss.backward()
    model.optimizer.step()

    counter += 1

Training:   8%|████▉                                                            | 380/5000 [40:04<7:51:38,  6.13s/it]

In [None]:
# Kernel change recorded at each training iteration. Entry 0 of change_data (measured
# against the all-zero starting buffer) is skipped, so the curve starts at iteration 1.
fig, ax = plt.subplots()
ax.plot(range(1, len(change_data)), change_data[1:])
ax.set(title='LCL kernel change per iteration', xlabel='Iteration', ylabel='Summed kernel change')
plt.show()

### Combine the saved kernel frames into a GIF (and a video)

In [7]:
import os 

p1 = 'models/vgg_reconstructed_lcl/lcl_kernel_evolution/20220605_170000'
p2 = 'models/vgg_reconstructed_lcl/lcl_kernel_evolution/20220506_180000_ZeroInitialization'

p1_files = []
for f in os.listdir(p1):
    if f != '_change.png':
        p1_files.append((int(f.split('.png')[0]), f))
        
sorted(p1_files)

p2_files = []
for f in os.listdir(p2):
    if f != '_change.png':
        p2_files.append((int(f.split('.png')[0]), f))
        
sorted(p2_files)



[(0, '0.png'),
 (1, '1.png'),
 (2, '2.png'),
 (3, '3.png'),
 (4, '4.png'),
 (5, '5.png'),
 (6, '6.png'),
 (7, '7.png'),
 (8, '8.png'),
 (9, '9.png'),
 (10, '10.png'),
 (11, '11.png'),
 (12, '12.png'),
 (13, '13.png'),
 (14, '14.png'),
 (15, '15.png'),
 (16, '16.png'),
 (17, '17.png'),
 (18, '18.png'),
 (19, '19.png'),
 (20, '20.png'),
 (21, '21.png'),
 (22, '22.png'),
 (23, '23.png'),
 (24, '24.png'),
 (25, '25.png'),
 (26, '26.png'),
 (27, '27.png'),
 (28, '28.png'),
 (29, '29.png'),
 (30, '30.png'),
 (31, '31.png'),
 (32, '32.png'),
 (33, '33.png'),
 (34, '34.png'),
 (35, '35.png'),
 (36, '36.png'),
 (37, '37.png'),
 (38, '38.png'),
 (39, '39.png'),
 (40, '40.png'),
 (41, '41.png'),
 (42, '42.png'),
 (43, '43.png'),
 (44, '44.png'),
 (45, '45.png'),
 (46, '46.png'),
 (47, '47.png'),
 (48, '48.png'),
 (49, '49.png'),
 (50, '50.png'),
 (51, '51.png'),
 (52, '52.png'),
 (53, '53.png'),
 (54, '54.png'),
 (55, '55.png'),
 (56, '56.png'),
 (57, '57.png'),
 (58, '58.png'),
 (59, '59.png'),


In [11]:
import glob
from PIL import Image


def frames_to_gif(directory, files, total_ms=20 * 1000):
    """Write ``<directory>.gif`` from the frames listed in ``files``.

    ``files`` holds ``(iteration, filename)`` pairs; frames are written in iteration
    order and the GIF loops forever. ``total_ms`` is the target length of the whole
    animation in milliseconds.

    Raises ValueError if ``files`` is empty.
    """
    ordered = sorted(files)
    if not ordered:
        raise ValueError('no frames to write for ' + directory)
    # PIL's `duration` is the display time of EACH frame in ms; the previous per-frame
    # value of 20 * 1000 held every frame for 20 s. Spread total_ms across all frames.
    # NOTE(review): confirm the intended playback length.
    frame_ms = max(1, total_ms // len(ordered))
    # Lazy generator so frames are opened one at a time while PIL writes them.
    frames = (Image.open(directory + '/' + name) for _, name in ordered)
    first = next(frames)
    first.save(fp=directory + '.gif', format='GIF', append_images=frames,
               save_all=True, duration=frame_ms, loop=0)


frames_to_gif(p1, p1_files)

In [None]:

# Write <p2>.gif from the numbered frames in iteration order, looping forever.
# PIL's `duration` is per frame in ms, so a 20 s total is divided across the frames
# (a per-frame value of 20 * 1000 would hold each frame for 20 s).
# NOTE(review): confirm the intended playback length.
p2_ordered = sorted(p2_files)
if not p2_ordered:
    raise ValueError('no frames to write for ' + p2)
p2_frames = (Image.open(p2 + '/' + name) for _, name in p2_ordered)
first_frame = next(p2_frames)
first_frame.save(fp=p2 + '.gif', format='GIF', append_images=p2_frames, save_all=True,
                 duration=max(1, 20 * 1000 // len(p2_ordered)), loop=0)

In [19]:
import cv2

# Stitch the p2 frames (in iteration order) into an XVID-encoded AVI at 24 fps.
fps = 24
ordered_frames = sorted(p2_files)
if not ordered_frames:
    raise ValueError('no frames to write for ' + p2)

# cv2.VideoWriter silently drops frames whose size differs from the size it was opened
# with, so take (width, height) from the first frame instead of hard-coding it. The
# kernel grids are saved with a 40x40-inch figsize, so at matplotlib's default dpi they
# are far larger than 1000x1000.
first = cv2.imread(p2 + '/' + ordered_frames[0][1])
if first is None:
    raise IOError('could not read frame ' + ordered_frames[0][1])
height, width = first.shape[:2]

video = cv2.VideoWriter(p2 + '.avi', cv2.VideoWriter_fourcc(*'XVID'), fps, (width, height))
try:
    for _, name in ordered_frames:
        # Read one frame at a time instead of holding every large frame in memory.
        frame = cv2.imread(p2 + '/' + name)
        if frame is None:
            raise IOError('could not read frame ' + name)
        video.write(frame)
finally:
    # Without release() the AVI container is never finalized.
    video.release()
