In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
import pandas as pd
from args import Args
import copy
import numpy as np
from collections import defaultdict
import os
from PIL import Image
from tqdm import tqdm
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset
from google.cloud import storage
import random

# Google Cloud Storage handles; `bucket` is the destination used by `save_model`.
# NOTE(review): project and bucket names are hard-coded; consider moving them into Args.
GCS_PROJECT = "leo_font"
GCS_BUCKET = "leo_font"
storage_client = storage.Client(project=GCS_PROJECT)
bucket = storage_client.bucket(GCS_BUCKET)

In [2]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Project configuration (provides e.g. `datapath` and `savepath` used below).
args = Args()

# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing: resize to a fixed 224x224, convert to a float tensor,
# then normalize each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [3]:
class CustomDataset(Dataset):
    """Images of a single font read from one flat directory.

    Files are expected to be named ``<font>__<anything>.png``. Only ``.png``
    files whose prefix before the first ``"__"`` equals ``font`` are kept.

    Args:
        root_dir: Directory containing the image files.
        font: Font name to select.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, font, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.font = font
        # Exact prefix match (the same `split("__")[0]` used to build the font list).
        # A substring test would also pick up e.g. "RobotoMono__a.png" for "Roboto".
        # Sorted so that indices are stable across runs (os.listdir order is arbitrary).
        self.files = [
            f"{root_dir}/{name}"
            for name in sorted(os.listdir(root_dir))
            if name.endswith(".png") and name.split("__")[0] == font
        ]

    def __len__(self):
        """Number of matching image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        path = self.files[idx]
        # Close the file handle instead of leaking it. Convert to 3 channels because
        # the 3-value Normalize and the pretrained backbone expect RGB.
        # The conversion is a no-op for files that are already RGB.
        with Image.open(path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

In [4]:
def save_model(state_dict, save_path):
    """Serialize ``state_dict`` with ``torch.save`` into the object ``save_path``
    of the module-level GCS ``bucket``.

    The blob is opened for binary writing and ``torch.save`` writes into it directly.
    """
    # ignore_flush=True: the GCS blob writer does not support flush() without close;
    # this turns flush() (which torch.save may call) into a no-op instead of an error.
    with bucket.blob(save_path).open("wb", ignore_flush=True) as out_file:
        torch.save(state_dict, out_file)

In [5]:
# Font names are the filename prefix before "__". Only .png files that actually carry
# that separator are considered. Any other entry (e.g. a stray .DS_Store or a non-image
# file) would otherwise become a bogus "font" with an empty dataset, and sampling from
# that dataset in `dataload` would fail.
fonts = np.unique([
    name.split("__")[0]
    for name in os.listdir(f"{args.datapath}/seen")
    if name.endswith(".png") and "__" in name
])

In [6]:
# Load data
# One dataset per font; each font acts as one class during episodic sampling.
seen_dir = f'{args.datapath}/seen'
dataset_list = []
for font_name in fonts:
    dataset_list.append(CustomDataset(root_dir=seen_dir, font=font_name, transform=transform))

In [7]:
def dataload(dataset_list, n_classes, n_per_class):
    """Sample one few-shot episode of support and query items.

    Picks ``n_classes`` *distinct* datasets (one per class) and, from each,
    ``n_per_class`` support items and ``n_per_class`` query items.

    Args:
        dataset_list: Sequence of indexable datasets, one per class.
        n_classes: Number of classes (datasets) in the episode.
        n_per_class: Number of support items and of query items per class.

    Returns:
        ``(supports, queries)``: items stacked with ``torch.stack``, each of length
        ``n_classes * n_per_class``. Items are ordered class-major: all items of the
        first sampled class, then the second, and so on. Item ``k`` therefore belongs
        to class ``k // n_per_class``.

    Raises:
        ValueError: if ``n_classes`` exceeds ``len(dataset_list)``, or if a sampled
            dataset is empty.
    """
    if n_classes > len(dataset_list):
        raise ValueError(
            f"n_classes={n_classes} exceeds the number of datasets ({len(dataset_list)})"
        )
    # Sample without replacement. Drawing the same dataset twice would create two
    # "classes" with identical content but different labels.
    ds_ids = np.random.choice(len(dataset_list), n_classes, replace=False)
    n_draw = 2 * n_per_class
    supports = []
    queries = []
    for i in ds_ids:
        ds = dataset_list[i]
        # Draw support and query indices jointly so the two sets are disjoint whenever
        # the dataset is large enough. Otherwise fall back to sampling with replacement.
        idx = np.random.choice(len(ds), n_draw, replace=len(ds) < n_draw)
        supports.extend(ds[j] for j in idx[:n_per_class])
        queries.extend(ds[j] for j in idx[n_per_class:])
    return torch.stack(supports), torch.stack(queries)

def euclidean_dist(x, y):
    """Pairwise *squared* Euclidean distances between the rows of two matrices.

    Args:
        x: Tensor of shape (N, D).
        y: Tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M) whose entry ``[i, j]`` is ``||x[i] - y[j]||**2``.
        Despite the name, no square root is taken.

    Raises:
        ValueError: if the feature dimensions of ``x`` and ``y`` differ.
    """
    if x.size(1) != y.size(1):
        raise ValueError(
            f"feature dimensions differ: x has {x.size(1)}, y has {y.size(1)}"
        )
    # Broadcast (N, 1, D) - (1, M, D) -> (N, M, D), then reduce over D.
    return (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(2)

In [8]:
class ModifiedEfficientNet(nn.Module):
    """ImageNet-pretrained torchvision EfficientNet-B7 with a ``tanh`` on its output.

    The complete torchvision model is kept, including its classification head.
    The embedding is therefore the head's output passed element-wise through ``tanh``.
    """

    def __init__(self):
        super().__init__()
        backbone_weights = EfficientNet_B7_Weights.IMAGENET1K_V1
        self.effnet = efficientnet_b7(weights=backbone_weights)

    def forward(self, x):
        # Element-wise tanh bounds every output coordinate to [-1, 1].
        return self.effnet(x).tanh()

In [9]:
# Build the embedding network and move its parameters to the selected device.
model = ModifiedEfficientNet().to(device)

In [10]:
# Episode shape: classes per episode, and support (and query) images per class.
n_classes, n_per_class = 4, 4
# Prefix for checkpoint filenames.
model_name = "effproto"

In [None]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Queries come back from `dataload` class-major, so query k belongs to class
# k // n_per_class. Class-index targets give the same loss as the former one-hot rows.
label = torch.arange(n_classes).repeat_interleave(n_per_class).to(device)
trailing_loss = 1
# Lower bound on distances before taking the log. log(0) = -inf would make the
# logits, and then the gradients, inf/NaN.
DIST_EPS = 1e-8

# Training loop: each iteration ("epoch" here) is one randomly sampled episode.
num_epochs = 1000000
pbar = tqdm(range(num_epochs))
for epoch in pbar:
    model.train()
    optimizer.zero_grad()
    # Sample an episode and move it to the device.
    support_x, query_x = dataload(dataset_list, n_classes, n_per_class)
    support_x = support_x.to(device)
    query_x = query_x.to(device)
    # One prototype per class: the mean embedding of that class's n_per_class
    # consecutive supports. The reshape groups by class. Splitting into n_per_class
    # chunks would mix classes whenever n_classes != n_per_class.
    support_y = model(support_x)
    prototypes = support_y.reshape(n_classes, n_per_class, -1).mean(1)
    query_y = model(query_x)
    # Nearer prototypes get larger logits.
    # NOTE(review): canonical prototypical-network logits are -dist; the -log(dist)
    # objective is kept as written, only guarded against log(0).
    dist = euclidean_dist(query_y, prototypes)
    loss = criterion(-torch.log(dist.clamp_min(DIST_EPS)), label)
    loss.backward()
    optimizer.step()

    # Exponential moving average of the loss, shown on the progress bar.
    trailing_loss = 0.95 * trailing_loss + 0.05 * loss.item()
    pbar.set_postfix(trailing_loss=trailing_loss)
    if epoch % 1000 == 0:
        save_model(model.state_dict(), f"{args.savepath}/{model_name}_{epoch}.pth")

  0%|          | 199/1000000 [02:38<196:29:49,  1.41it/s, trailing_loss=0.643]

In [12]:
# Display the loss tensor from the most recent training step. This requires the
# training cell to have executed at least one step, otherwise `loss` is undefined.
loss

tensor(0.9886, device='cuda:0', grad_fn=<DivBackward1>)