In [None]:
%load_ext autoreload 
%autoreload 2

import os
import random
import numpy as np
import scipy.linalg as sl
from PIL import Image
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns
from IPython import display

import torch
from torch import nn, distributions as dist, autograd
from torch.func import jacfwd
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, ToTensor, Normalize
# Global plotting / dtype configuration for the whole notebook.
plt.style.use("seaborn-v0_8")
torch.set_default_dtype(torch.float32)
# torch.set_default_device("cuda")

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) with the same value.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:

# Root folder of the dataset; it must contain `names.txt` (one class name per line).
DATASET_PATH = "/mnt/dl/datasets/Oxford102FlowersSplits/"
os.environ["KERAS_BACKEND"] = "tensorflow"

# Map class index -> class name. The context manager closes the file handle,
# which the bare `open(...)` inside a comprehension never did.
with open(os.path.join(DATASET_PATH, "names.txt")) as names_file:
    LABELS = {i: k.strip() for i, k in enumerate(names_file)}

img_size = 112
batch_size = 32
num_classes = len(LABELS)
patch_size = 16

# A patch count must be an integer (true division produced the float 49.0).
assert img_size % patch_size == 0, "img_size must be a multiple of patch_size"
num_patches = (img_size // patch_size) ** 2

In [None]:
class FlowerDataset(Dataset):
    """Map-style dataset over one split of an image folder with integer labels.

    Layout read by ``load_data``::

        <path>/<split>/jpeg/<N>.jpeg       # N is an integer file stem
        <path>/<split>/label/label.txt     # one integer label per line

    Images are ordered by their numeric stem and paired with labels by position.

    Args:
        path: dataset root directory.
        split: sub-directory name of the split (e.g. ``"train"``).
        cache: when true (default) every decoded image is kept in
            ``self.samples`` after its first access; when false images are
            decoded on every access and nothing is retained.
        transforms: optional callable applied to the decoded float32 array on
            every access (i.e. after caching, so random transforms stay random).
    """

    def __init__(self, path, split, cache=True, transforms=None):
        super().__init__()
        self.cache = cache
        self.transforms = transforms
        self.samples = dict()  # index -> decoded float32 array (used only if cache)
        self.load_data(path, split)

    def load_data(self, path, split):
        """Collect image paths (numerically sorted) and labels for ``split``.

        Raises:
            ValueError: if a file stem or a label line is not an integer.
        """
        split_dir = os.path.join(path, split)
        image_dir = os.path.join(split_dir, "jpeg")

        # Sort by integer stem so "10.jpeg" comes after "2.jpeg".
        img_files = sorted(os.listdir(image_dir), key=lambda x: int(x.replace(".jpeg", "")))

        # `with` closes the label file; the original left the handle open.
        with open(os.path.join(split_dir, "label", "label.txt")) as label_file:
            self.labels = [int(line.strip()) for line in label_file]

        self.img_files = [os.path.join(image_dir, name) for name in img_files]

    def __len__(self):
        """Number of image files found for the split."""
        return len(self.img_files)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for ``index``, applying ``transforms`` if set."""
        if self.cache:
            if index not in self.samples:
                self.load_sample(index)
            sample = self.samples[index]
        else:
            # Honour cache=False: decode on demand and keep nothing in memory.
            sample = self._read_image(index)

        if self.transforms is not None:
            sample = self.transforms(sample)

        return (sample, self.labels[index])

    def _read_image(self, idx):
        """Decode image ``idx`` into a float32 NumPy array and close the file."""
        with Image.open(self.img_files[idx]) as img:
            return np.array(img).astype(np.float32)

    def load_sample(self, idx):
        """Decode image ``idx`` and store it in ``self.samples``; returns True."""
        self.samples[idx] = self._read_image(idx)
        return True


In [None]:
# Deterministic preprocessing shared by the validation and test splits.
# FlowerDataset decodes images to float32 arrays; ToTensor is expected to leave
# float input unscaled (TODO confirm), so Normalize(0, 255) performs the division by 255.
eval_transforms = Compose([
    ToTensor(),
    Resize((img_size, img_size)),
    Normalize(0., 255.0),
])

# Training adds a light random horizontal flip (p=0.1); the vertical flip is
# kept at p=0 so it stays a no-op.
train_ds = FlowerDataset(DATASET_PATH, "train", transforms=Compose([
    ToTensor(),
    Resize((img_size, img_size)),
    RandomHorizontalFlip(0.1),
    RandomVerticalFlip(0.),
    Normalize(0., 255.0),
]))
val_ds = FlowerDataset(DATASET_PATH, "validation", transforms=eval_transforms)
test_ds = FlowerDataset(DATASET_PATH, "test", transforms=eval_transforms)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
# Evaluation must see every validation sample exactly once: no shuffling and no
# drop_last (dropping the final partial batch silently excluded samples from
# the reported accuracy).
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, drop_last=False)
test_loader = DataLoader(test_ds, batch_size=16)

In [None]:
from torchvision.models import resnet34

# `num_classes` must be passed by keyword. Given positionally, torchvision
# treats the value as the legacy `pretrained` flag, so a non-zero class count
# requested ImageNet weights and left the head at its default size.
# With the keyword, the network is built untrained and its final `fc` layer
# already has `num_classes` outputs, so no manual head replacement is needed.
ebm = resnet34(num_classes=num_classes)


In [None]:
# Per-parameter-tensor summary statistics of the model:
# column 0 holds the standard deviation, column 1 the mean.
param_tensors = list(ebm.parameters())
std = torch.tensor([tensor.std() for tensor in param_tensors])
mean = torch.tensor([tensor.mean() for tensor in param_tensors])
params = torch.stack((std, mean), dim=1)

In [None]:
# Each column (std, mean) sorted independently in ascending order.
torch.sort(params, dim=0).values

In [None]:
# Average std and average mean across all parameter tensors.
torch.mean(params, dim=0)

In [None]:
# Move the model to the GPU when one is available; fall back to the CPU instead
# of raising on a CPU-only machine. `.to` returns the module, so the cell still
# displays the architecture.
ebm.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [None]:
# Scalar checksum of all weights (sum of every parameter element), printed as a float.
print(sum(weight.sum().item() for weight in ebm.parameters()))

In [None]:
# Re-initialise the network by layer type.
# The previous rule zeroed *every* 1-D parameter, which includes the BatchNorm
# scale: with scale 0 and shift 0 each BN layer outputs zero, so the logits did
# not depend on the input at all.
with torch.no_grad():
    for module in ebm.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Xavier-normal weights, zero biases (as before for these layers).
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            # Identity affine transform and fresh running statistics.
            module.reset_running_stats()
            if module.affine:
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
print(sum([p.sum().item() for p in ebm.parameters()]))

## Pretraining

In [None]:
class ClassificationTrainer:
    """Minimal supervised training loop for a classifier.

    Uses mean cross-entropy loss, Adam (lr=1e-4) and gradient-norm clipping at 10.
    Batches are moved to the device the model's parameters live on, so the
    trainer works on GPU and CPU alike.

    Args:
        model: ``nn.Module`` mapping an image batch to per-class logits.
        train_loader: iterable of ``(images, labels)`` batches.
        val_loader: optional iterable of ``(images, labels)`` batches; when
            ``None`` no evaluation is performed.
        epochs: number of passes over ``train_loader``.
        eval_epochs: evaluate every ``eval_epochs``-th epoch (0 disables).
        savepath: directory for checkpoints (``<savepath>/model``) and
            per-evaluation folders (``<savepath>/eval``); when ``None`` nothing
            is written to disk.
    """

    def __init__(self, model, train_loader, val_loader=None, epochs=1, eval_epochs=0, savepath=None):
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.eval_epochs = eval_epochs
        self.savepath = savepath

        # Defined up front so save_model() is usable before train() has run.
        self.train_losses = []
        self.acc = []

        if savepath is not None:
            self.eval_savepath = os.path.join(savepath, "eval")
            self.model_savepath = os.path.join(savepath, "model")
            os.makedirs(self.model_savepath, exist_ok=True)
            os.makedirs(self.eval_savepath, exist_ok=True)
        else:
            # The default savepath=None used to crash in os.path.join.
            self.eval_savepath = None
            self.model_savepath = None

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)

    @property
    def device(self):
        """Device of the model's first parameter; batches are moved here."""
        return next(self.model.parameters()).device

    def train(self,):
        """Run all epochs, evaluating and checkpointing along the way.

        Populates ``self.train_losses`` (one entry per step) and ``self.acc``
        (one percentage per evaluation). Saves ``best_model`` whenever the
        validation accuracy improves and ``last_model`` at the end.

        Returns:
            True once training has finished.
        """
        self.train_losses = []
        self.acc = []
        best_acc = 0.
        # Evaluation needs both a schedule and a validation loader.
        can_evaluate = self.eval_epochs > 0 and self.val_loader is not None
        for i in range(self.epochs):
            ep_losses = self.run_epoch(i)
            self.train_losses.extend(ep_losses)
            if can_evaluate and i % self.eval_epochs == 0:
                acc = self.eval_epoch(i)
                if acc > best_acc:
                    best_acc = acc
                    self.save_model(fname="best_model", epoch=i)
                print("**" * 20 + f"Epoch {i} acc: {acc}")
                self.acc.append(acc.item())
        print("Succesfully trained...")
        self.save_model("last_model", self.epochs)
        return True

    def save_model(self, fname, epoch=0):
        """Write model/optimizer state, step losses and ``epoch`` to ``<savepath>/model/<fname>``.

        Does nothing when the trainer was created without a ``savepath``.
        """
        if self.model_savepath is None:
            return
        torch.save({"model": self.model.state_dict(),
                    "optimizers": self.optimizer.state_dict(),
                    "losses": self.train_losses,
                    "epoch": epoch
                    }, os.path.join(self.model_savepath, fname))

    def run_epoch(self, epoch):
        """Train for one pass over ``train_loader``; returns the per-step losses."""
        losses = []
        device = self.device
        self.model.train()
        for j, (img, label) in enumerate(self.train_loader):
            img, label = img.to(device), label.to(device)
            pred = self.model(img)
            loss = self.loss_fn(pred, label)
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.)
            self.optimizer.step()
            losses.append(loss.item())
            if j % 5 == 0:
                # Running mean of this epoch's losses so far.
                print(f"Epoch {epoch}, step {j}, loss: {np.mean(losses)}")

        return losses

    def eval_epoch(self, epoch):
        """Compute the accuracy (in percent) over ``val_loader``.

        Returns:
            0-d float tensor with the mean per-sample accuracy times 100, or
            NaN when the loader yields no batches.
        """
        if self.eval_savepath is not None:
            savepath = os.path.join(self.eval_savepath, f"{epoch:05d}")
            os.makedirs(savepath, exist_ok=True)
        device = self.device
        self.model.eval()
        print(f"Evaluating {epoch}")
        batch_hits = []
        with torch.no_grad():
            for img, label in self.val_loader:
                img, label = img.to(device), label.to(device)
                pred = self.model(img)
                batch_hits.append(self.get_accuracy(pred, label))

        self.model.train()

        if not batch_hits:
            # Mean over zero samples: report NaN instead of crashing.
            return torch.tensor(float("nan"), device=device)

        # Concatenate per-sample hits so every sample carries equal weight.
        return torch.cat(batch_hits).mean() * 100.

    def get_accuracy(self, input, target):
        """Per-sample correctness: 1.0 where the arg-max class equals ``target``, else 0.0."""
        predicted = input.argmax(dim=1)
        return (predicted == target).to(torch.float32)
            

In [None]:
# Trainer for the supervised stage: 1000 epochs, validation on every 20th
# epoch (0, 20, 40, ...), artefacts written under the given directory.
flower_classifier = ClassificationTrainer(
    model=ebm,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=1000,
    eval_epochs=20,
    savepath="/mnt/dl/generation/ebm/classification",
)

In [None]:
# Checksum of all weights before training, for comparison with the value afterwards.
sum(weight.sum() for weight in ebm.parameters())

In [None]:
# Run the training loop: logs the running loss every 5 steps, evaluates on the
# validation loader on the configured schedule, saves checkpoints, and returns True.
flower_classifier.train()

In [None]:
# Checksum of all weights after training; it should differ from the pre-training value.
sum(weight.sum() for weight in ebm.parameters())