In [1]:
%load_ext autoreload 
%autoreload 2

import os
import random
import numpy as np
import scipy.linalg as sl
from PIL import Image
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns
from IPython import display

import torch
from torch import nn, distributions as dist, autograd
from torch.func import jacfwd
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, ToTensor, Normalize
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
# torch.set_default_device("cuda")
torch.set_default_dtype(torch.float32)
plt.style.use('seaborn-v0_8')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

DATASET_PATH = "/mnt/dl/datasets/Oxford102FlowersSplits/"
os.environ["KERAS_BACKEND"] = "tensorflow"
LABELS = {i: k.strip() for i, k in enumerate(open(os.path.join(DATASET_PATH, "names.txt")))}
img_size = 112
batch_size = 32
num_classes = len(LABELS)
patch_size = 16
num_patches = img_size ** 2 / patch_size **2

In [3]:
class FlowerDataset(Dataset):
    def __init__(self, path, split, cache=True, transforms=None):
        super().__init__()
        self.load_data(path, split)
        self.samples = dict()
        self.transforms = transforms
        
    def load_data(self, path, split):
        path = os.path.join(path, split, )
        img_files = os.listdir(os.path.join(path, "jpeg"))
        img_files = sorted(img_files, key=lambda x: int(x.replace(".jpeg", "")))
        img_files = list(img_files)
        
        labels = list(open(os.path.join(path, "label", "label.txt"),))
        self.labels = [int(l.strip()) for l in labels]
        
        self.img_files = [os.path.join(path, "jpeg", name) for name in img_files]
    
    def __len__(self):
        return len(self.img_files)
    
    def __getitem__(self, index):
        if index not in self.samples:
            self.load_sample(index)
        sample = self.samples[index]
        if self.transforms is not None:
            sample = self.transforms(sample)

        return (sample, self.labels[index])
        
    def load_sample(self, idx):
        img = Image.open(self.img_files[idx])
        img = np.array(img).astype(np.float32)
        self.samples[idx] = img
        return True


In [4]:
train_ds = FlowerDataset(DATASET_PATH, "train", transforms=Compose([
    ToTensor(),
    Resize((img_size, img_size)),
    RandomHorizontalFlip(0.1),
    RandomVerticalFlip(0.),
    Normalize(0., 255.0)
]))
val_ds = FlowerDataset(DATASET_PATH, "validation", transforms=Compose([
    ToTensor(),
    Resize((img_size, img_size)),
    Normalize(0., 255.0)
    
]))

test_ds = FlowerDataset(DATASET_PATH, "test", transforms=Compose([
    ToTensor(),
    Resize((img_size, img_size)),
    Normalize(0., 255.0)
    
]))

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=True, drop_last=True)
test_loader = DataLoader(test_ds, batch_size=16)

## EBM Resnet Model

In [5]:
class Residual(nn.Module):
    
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding='same'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, 
                               out_channels=out_channel,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, 
                                out_channels=out_channel,
                                padding="same", kernel_size=kernel_size)
        self.bn2 = nn.BatchNorm2d(out_channel)

        self.downsample = None
        
        if out_channel != in_channel:
            self.downsample = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=2,
                                        bias=None)
        
        self.apply(self.initialize_parameters)
    
    def initialize_parameters(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                nn.init.zeros_(m.bias.data)
                
    def forward(self, x):
        prev_x = x
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.gelu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.gelu(x)
        
        if self.downsample:
            prev_x = self.downsample(prev_x)
        
        return x + prev_x
        
class ResBlock(nn.Module):
    
    def __init__(self, nblk, in_channel, out_channel, kernel_size, stride):
        super().__init__()
        
        self.net = nn.ModuleList()
        self.net.append(Residual(in_channel, out_channel, kernel_size, stride,
                                 padding='same' if stride==1 else 'valid'))
        for i in range(1, nblk):
            self.net.append(Residual(out_channel, out_channel, kernel_size, stride=1))
    
    def forward(self, x):
        for net in self.net:
            x = net(x)
        return x

# class GlobalPooling2D(nn.Module):
        
#     def forward(self, x):
#         return x.mean([2, 3])
    

class EBM(nn.Module):
    
    def __init__(self, num_classes):
        super().__init__()
        
        filters = [64, 64, 128, 256, 512]
        nblocks = [2, 4, 4, 3]
        kernels = [3, 3, 3, 3]
        strides = [1, 2, 2, 2]
        self.conv1 = nn.Conv2d(3, filters[0], 3, padding=1)
        self.bn1 = nn.BatchNorm2d(filters[0])
        self.maxpool = nn.MaxPool2d(2,)
        self.res_blocks = nn.ModuleList()
        for i, nblk in enumerate(nblocks):
            self.res_blocks.append(ResBlock(nblk, 
                                            in_channel=filters[i],
                                            out_channel=filters[i+1], 
                                            kernel_size=kernels[i],
                                            stride=strides[i]))
        # self.avg_pool = GlobalPooling2D()
        self.dense = nn.Sequential(nn.Linear(filters[-1], filters[-1] * 4),
                                   nn.ReLU(),
                                   nn.Dropout(),
                                   nn.Linear(filters[-1] * 4, num_classes)
                                   )
        with torch.no_grad():
            for m in self.dense.modules():
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, -0.05, 0.05)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
    

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        
        x = self.maxpool(x)
        
        x = self.res_blocks[0](x)
        x = self.res_blocks[1](x)
        x = self.res_blocks[2](x)
        x = self.res_blocks[3](x)
        
        # x = self.avg_pool(x)
        x = F.avg_pool2d(x, x.size(2), )
        x =  x.view((x.size(0), -1))
        x = self.dense(x)
        return x


In [None]:
ebm = EBM(num_classes)
ebm.cuda()

In [6]:
from torchvision.models import resnet34

In [7]:
ebm = resnet34(num_classes)
ebm.fc = nn.Linear(512, num_classes)
with torch.no_grad():
    for p in ebm.parameters():
        if len(p.size()) == 1:
            nn.init.zeros_(p)
            continue
        nn.init.normal_(p, -0.0005, 0.005)



In [8]:
ebm(torch.randn((2, 3, 224, 224)).cuda()).size()

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [9]:
ebm.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Pretraining

In [10]:
class ClassificationTrainer:
    
    def __init__(self, model, train_loader, val_loader=None, epochs=1, eval_epochs=0, savepath=None):
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")
        self.train_loader = train_loader 
        self.val_loader = val_loader
        self.epochs = epochs
        self.eval_epochs = eval_epochs
        self.savepath = savepath
        self.eval_savepath = os.path.join(self.savepath, "eval")
        self.model_savepath = os.path.join(self.savepath, "model")
        
        os.makedirs(self.model_savepath, exist_ok=True)
        os.makedirs(self.eval_savepath, exist_ok=True)
        
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        
        
    def train(self,):
        self.train_losses = []
        self.acc = []
        best_loss = 0.
        for i in range(self.epochs):
            ep_losses = self.run_epoch(i)
            self.train_losses.extend(ep_losses)
            if self.eval_epochs > 0 and i % self.eval_epochs == 0:
                acc = self.eval_epoch(i)
                if acc > best_loss:
                    best_loss = acc
                    self.save_model(fname="best_model", epoch=i)
                print("**" * 20 + f"Epoch {i} acc: {acc}")
                self.acc.append(acc.item())
        print("Succesfully trained...")
        self.save_model(f"last_model", self.epochs)
        return True
    
    def save_model(self, fname, epoch=0):
        torch.save({"model": self.model.state_dict(),
                    "optimizers": self.optimizer.state_dict(),
                    "losses": self.train_losses,
                    "epoch": epoch
                    
            }, os.path.join(self.model_savepath, fname))

    def run_epoch(self, epoch):
        losses = []
        self.model.train()
        for j, (img, label) in enumerate(self.train_loader):
            img, label = img.cuda(), label.cuda()
            pred = self.model(img)
            loss = self.loss_fn(pred, label)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            # print(sum([p.sum() for p in ebm.parameters()]), img.mean(), label.float().mean())
            if j % 5 == 0:
                print(f"Epoch {epoch}, step {j}, loss: {np.mean(losses)}")
                
        return losses
    
    def eval_epoch(self, epoch):
        savepath = os.path.join(self.eval_savepath, f"{epoch:05d}")
        os.makedirs(savepath, exist_ok=True)
        self.model.eval()
        print(f"Evaluating {epoch}")
        ep_acc = []
        with torch.no_grad():
            for k, (img, label) in enumerate(self.val_loader):
                img, label = img.cuda(), label.cuda()
                pred = self.model(img)
                acc = self.get_accuracy(pred, label)
                ep_acc.extend(acc)
        
        self.model.train()
        
        return  torch.stack(ep_acc).mean() * 100.
    

    def get_accuracy(self, input, target):
        inp_argmax = input.argmax(axis=1)
        acc = inp_argmax == target
        acc = acc.to(torch.float32)
        
        return acc
            

In [11]:
flower_classifier = ClassificationTrainer(ebm, train_loader=train_loader, val_loader=val_loader,
                                      epochs=100, eval_epochs=20, savepath="/mnt/dl/generation/ebm/classification")

In [12]:
sum([p.sum() for p in ebm.parameters()])

tensor(-10665.7705, device='cuda:0', grad_fn=<AddBackward0>)

In [13]:
flower_classifier.train()



Epoch 0, step 0, loss: 4.62497091293335
Epoch 0, step 5, loss: 4.624984502792358
Epoch 0, step 10, loss: 4.625018163160845
Epoch 0, step 15, loss: 4.6250220239162445
Epoch 0, step 20, loss: 4.6250337191990445
Epoch 0, step 25, loss: 4.625032406586867
Epoch 0, step 30, loss: 4.625058743261522
Evaluating 0




****************************************Epoch 0 acc: 0.9920635223388672
Epoch 1, step 0, loss: 4.624999523162842
Epoch 1, step 5, loss: 4.6249903837839765
Epoch 1, step 10, loss: 4.625003251162442
Epoch 1, step 15, loss: 4.625003427267075
Epoch 1, step 20, loss: 4.625004064469111
Epoch 1, step 25, loss: 4.625011645830595
Epoch 1, step 30, loss: 4.6250187812312955
Epoch 2, step 0, loss: 4.624924182891846
Epoch 2, step 5, loss: 4.624948183695476
Epoch 2, step 10, loss: 4.624965060840953
Epoch 2, step 15, loss: 4.624975860118866
Epoch 2, step 20, loss: 4.624977951958066
Epoch 2, step 25, loss: 4.62500262260437
Epoch 2, step 30, loss: 4.6250205962888655
Evaluating 2


KeyboardInterrupt: 

In [36]:
sum([p.sum() for p in ebm.parameters()])

tensor(-56383.9961, device='cuda:0', grad_fn=<AddBackward0>)

In [19]:
sum([1 for p in ebm.parameters()])

107

In [21]:
for name, p in ebm.named_parameters():
    print(name)

conv1.weight
conv1.bias
bn1.weight
bn1.bias
res_blocks.0.net.0.conv1.weight
res_blocks.0.net.0.conv1.bias
res_blocks.0.net.0.bn1.weight
res_blocks.0.net.0.bn1.bias
res_blocks.0.net.0.conv2.weight
res_blocks.0.net.0.conv2.bias
res_blocks.0.net.0.bn2.weight
res_blocks.0.net.0.bn2.bias
res_blocks.0.net.1.conv1.weight
res_blocks.0.net.1.conv1.bias
res_blocks.0.net.1.bn1.weight
res_blocks.0.net.1.bn1.bias
res_blocks.0.net.1.conv2.weight
res_blocks.0.net.1.conv2.bias
res_blocks.0.net.1.bn2.weight
res_blocks.0.net.1.bn2.bias
res_blocks.1.net.0.conv1.weight
res_blocks.1.net.0.conv1.bias
res_blocks.1.net.0.bn1.weight
res_blocks.1.net.0.bn1.bias
res_blocks.1.net.0.conv2.weight
res_blocks.1.net.0.conv2.bias
res_blocks.1.net.0.bn2.weight
res_blocks.1.net.0.bn2.bias
res_blocks.1.net.0.downsample.weight
res_blocks.1.net.1.conv1.weight
res_blocks.1.net.1.conv1.bias
res_blocks.1.net.1.bn1.weight
res_blocks.1.net.1.bn1.bias
res_blocks.1.net.1.conv2.weight
res_blocks.1.net.1.conv2.bias
res_blocks.1.net.

In [33]:
from torchvision.models import resnet34
resnet = resnet34()

In [34]:
sum([1 for p in resnet.parameters()])

110

In [35]:
for name, p in resnet.named_parameters():
    print(name)

conv1.weight
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.2.conv1.weight
layer1.2.bn1.weight
layer1.2.bn1.bias
layer1.2.conv2.weight
layer1.2.bn2.weight
layer1.2.bn2.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.weight
layer2.1.bn2.bias
layer2.2.conv1.weight
layer2.2.bn1.weight
layer2.2.bn1.bias
layer2.2.conv2.weight
layer2.2.bn2.weight
layer2.2.bn2.bias
layer2.3.conv1.weight
layer2.3.bn1.weight
layer2.3.bn1.bias
layer2.3.conv2.weight
layer2.3.bn2.weight
layer2.3.bn2.bias
layer3.0.conv1.weight
layer3.0.bn1.weight
