In [1]:
import argparse

import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from tqdm import tqdm

from lateral_connections import LateralModel, VggModel
from lateral_connections import VggWithLCL
from lateral_connections import MNISTCDataset
from lateral_connections.loaders import get_loaders, load_mnistc
from lateral_connections.character_models import SmallVggWithLCL

import datetime

In [2]:
# Hyperparameters for this notebook, gathered in one place so they can be tuned together.
config = dict(
    # Classifier / optimisation settings.
    num_classes=10,
    learning_rate=1e-3,
    dropout=0.2,
    num_epochs=4,
    batch_size=10,
    # Lateral-connection (LCL) settings; the lcl_* values and num_multiplex are
    # forwarded to the SmallVggWithLCL constructor in this cell.
    use_lcl=True,
    num_multiplex=4,
    lcl_alpha=1e-3,
    lcl_theta=0.2,
    lcl_eta=0.0,
    lcl_iota=0.2,
)

# Instantiate the network from the values stored in `config`.
# do_wandb=False and an empty run_identifier are passed through unchanged
# (presumably this turns off Weights & Biases logging -- confirm in SmallVggWithLCL).
model = SmallVggWithLCL(
    config['num_classes'],
    learning_rate=config['learning_rate'],
    dropout=config['dropout'],
    num_multiplex=config['num_multiplex'],
    do_wandb=False,
    run_identifier="",
    # Forward the four lcl_* hyperparameters under their own names.
    **{key: config[key] for key in ('lcl_alpha', 'lcl_eta', 'lcl_theta', 'lcl_iota')},
)

# Bare trailing expression so the notebook displays the module's repr.
model

SmallVggWithLCL(
  (features): Sequential(
    (pool1): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (pool2): Sequential(
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (pool3): Sequential(
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3

In [3]:
# Load the MNIST-C variant named 'identity' (presumably the uncorrupted images -- confirm in load_mnistc).
dataset = load_mnistc('identity')
# Indexing the dataset yields an (image, label) pair; take the first one.
img, lbl = dataset[0]
# Sanity check: print the image tensor's shape together with its label.
print(img.shape, lbl)

torch.Size([1, 224, 224]) tensor(7)


In [4]:
# Push one sample through the feature extractor one stage at a time and print
# the activation shape after each stage.
# Only shapes are inspected here, so the forward pass runs under torch.no_grad()
# to avoid recording an autograd graph for every intermediate activation.
with torch.no_grad():
    # Add a leading batch axis and move the image onto the model's device.
    x = img.unsqueeze(0).to(model.device)

    # Stage names are attributes of model.features, applied in this order.
    for stage_name in ('pool1', 'pool2', 'pool3', 'lcl3', 'pool4', 'pool5'):
        x = getattr(model.features, stage_name)(x)
        print(x.shape)

torch.Size([1, 64, 112, 112])
torch.Size([1, 128, 56, 56])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 28, 28])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 7, 7])
