In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
import pandas as pd
from args import Args
import copy
import numpy as np
from collections import defaultdict
import os
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from google.cloud import storage

# Google Cloud Storage target for checkpoints and history CSVs.
# Defaults keep the original names; override with GCS_PROJECT / GCS_BUCKET env vars.
GCS_PROJECT = os.environ.get("GCS_PROJECT", "leo_font")
GCS_BUCKET = os.environ.get("GCS_BUCKET", "leo_font")

storage_client = storage.Client(project=GCS_PROJECT)
bucket = storage_client.bucket(GCS_BUCKET)

In [2]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = Args()

# Seed the RNGs used later (DataLoader shuffling, head initialisation) so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Define transforms: fixed input size + ImageNet normalisation statistics,
# matching the ImageNet-pretrained EfficientNet weights loaded for the model.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [3]:
class CustomDataset(Dataset):
    """Glyph-image dataset yielding ``(image, cho, jung, jong)`` samples.

    Every file in ``root_dir`` whose name ends with ``.png`` is one sample
    (the directory is not searched recursively). The letter a file depicts is
    the part of its file name after the last ``"__"`` with the extension
    removed, e.g. ``"<anything>__<letter>.png"``. ``labeler`` maps that letter
    to ``(cho_idx, jung_idx, jong_idx)``, which are returned as float32 one-hot
    tensors of length 19, 21 and 28 respectively.

    Args:
        root_dir: directory containing the PNG files.
        labeler: mapping ``letter -> (cho_idx, jung_idx, jong_idx)``.
        transform: optional callable applied to the RGB PIL image.
    """

    # Lengths of the three one-hot label vectors.
    NUM_CHO, NUM_JUNG, NUM_JONG = 19, 21, 28

    def __init__(self, root_dir, labeler, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.labeler = labeler
        # Sorted because os.listdir order is arbitrary; this keeps sample
        # indices stable across runs and machines.
        self.files = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.endswith(".png")
        ]

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _letter_from_path(path):
        """Return the letter encoded after the last ``"__"`` in the file name."""
        # Parse the base name only, so a "__" in a directory name cannot leak in.
        stem, _ = os.path.splitext(os.path.basename(path))
        return stem.split("__")[-1]

    @staticmethod
    def _one_hot(index, size):
        """Float32 one-hot tensor of length ``size`` (same dtype as the model's logits)."""
        return torch.from_numpy(np.eye(size, dtype=np.float32)[index])

    def __getitem__(self, idx):
        path = self.files[idx]
        cho, jung, jong = self.labeler[self._letter_from_path(path)]

        # Decode eagerly and close the file; force 3 channels so grayscale or
        # RGBA PNGs also fit the 3-channel Normalize / backbone (alpha is dropped).
        with Image.open(path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return (
            image,
            self._one_hot(cho, self.NUM_CHO),
            self._one_hot(jung, self.NUM_JUNG),
            self._one_hot(jong, self.NUM_JONG),
        )

In [4]:
class CustomPerformancer:
    """Running accuracy tracker for the cho / jung / jong classification heads.

    ``take`` consumes one batch of logits and one-hot labels. ``accuracies``
    reports per-head accuracy plus ``total_accuracy``, which counts a sample
    only if all three heads are correct. ``save_history`` appends the current
    snapshot under a name (e.g. ``'train'``/``'test'``). ``reset`` zeroes the
    counters but keeps ``histories``.
    """

    def __init__(self):
        self.histories = defaultdict(list)
        self.reset()

    def reset(self):
        """Zero the running counters; ``histories`` is left untouched."""
        self.total = 0
        self.total_correct = 0
        self.cho_correct = 0
        self.jung_correct = 0
        self.jong_correct = 0

    @staticmethod
    def _hits(logits, one_hot):
        """Per-row boolean tensor: argmax over dim 1 of ``logits`` equals that of ``one_hot``."""
        # detach(): metrics must not extend the autograd graph of training logits.
        return logits.detach().argmax(dim=1).eq(one_hot.argmax(dim=1))

    def take(self, outs, labels):
        """Accumulate one batch.

        Args:
            outs: ``(cho, jung, jong)`` logit tensors; assumed ``(batch, classes)``.
            labels: ``(cho, jung, jong)`` one-hot tensors with matching shapes.
        """
        choout, jungout, jongout = outs
        chol, jungl, jongl = labels

        chocorr = self._hits(choout, chol)
        jungcorr = self._hits(jungout, jungl)
        jongcorr = self._hits(jongout, jongl)
        allcorr = chocorr & jungcorr & jongcorr

        # Tensor.sum() reduces in one call instead of iterating element-wise.
        self.total += len(allcorr)
        self.total_correct += allcorr.sum().item()
        self.cho_correct += chocorr.sum().item()
        self.jung_correct += jungcorr.sum().item()
        self.jong_correct += jongcorr.sum().item()

    def accuracies(self):
        """Return accuracy fractions and the sample ``count``.

        When no samples have been taken the rates are NaN (undefined) rather
        than raising ZeroDivisionError.
        """
        count = self.total

        def rate(correct):
            return correct / count if count else float("nan")

        return {
            "total_accuracy": rate(self.total_correct),
            "cho_accuracy": rate(self.cho_correct),
            "jung_accuracy": rate(self.jung_correct),
            "jong_accuracy": rate(self.jong_correct),
            "count": count,
        }

    def save_history(self, name):
        """Append the current ``accuracies()`` snapshot to ``histories[name]``."""
        self.histories[name].append(self.accuracies())

In [5]:
def _open_blob_writer(save_path):
    """Open a binary write stream to ``save_path`` in the module-level GCS ``bucket``."""
    # ignore_flush=True: the GCS BlobWriter raises on flush() otherwise, and the
    # serializers below may flush; data is committed when the stream is closed.
    return bucket.blob(save_path).open("wb", ignore_flush=True)


def save_model(state_dict, save_path):
    """Serialize ``state_dict`` with ``torch.save`` to ``save_path`` in the bucket."""
    with _open_blob_writer(save_path) as f:
        torch.save(state_dict, f)


def save_history(hist, save_path):
    """Write ``hist`` (e.g. a list of metric dicts) as CSV to ``save_path`` in the bucket."""
    with _open_blob_writer(save_path) as f:
        pd.DataFrame(hist).to_csv(f)
    

In [6]:
# Load data
BATCH_SIZE = 24
NUM_WORKERS = 10
# Page-locked host memory speeds up host->GPU copies; pointless on CPU.
PIN_MEMORY = device.type == 'cuda'

train_dataset = CustomDataset(root_dir=f'{args.datapath}/seen', labeler=args.labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

test_dataset = CustomDataset(root_dir=f'{args.datapath}/unseen', labeler=args.labels, transform=transform)
# Order is kept fixed for evaluation; workers only parallelise image decoding.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

In [7]:
class ModifiedEfficientNet(nn.Module):
    """EfficientNet-B7 with three linear heads for cho / jung / jong classes.

    The heads read the backbone's complete output, i.e. the logits of its own
    final classifier layer (1000-wide for the ImageNet head), not the pooled
    convolutional features. ``forward`` returns ``(cho_out, jung_out, jong_out)``.

    Args:
        num_cho, num_jung, num_jong: output sizes of the three heads
            (defaults 19 / 21 / 28, presumably Hangul initial / medial / final
            jamo including an empty final — confirm against ``args.labels``).
        weights: torchvision weights enum for ``efficientnet_b7`` (``None`` for
            random initialisation).
    """

    def __init__(self, num_cho=19, num_jung=21, num_jong=28,
                 weights=EfficientNet_B7_Weights.IMAGENET1K_V1):
        super().__init__()
        self.effnet = efficientnet_b7(weights=weights)
        # Derive the head input width from the backbone's last Linear layer
        # instead of hard-coding 1000.
        in_features = self.effnet.classifier[-1].out_features
        self.cho_fc = nn.Linear(in_features, num_cho)
        self.jung_fc = nn.Linear(in_features, num_jung)
        self.jong_fc = nn.Linear(in_features, num_jong)

    def forward(self, x):
        features = self.effnet(x)
        return self.cho_fc(features), self.jung_fc(features), self.jong_fc(features)

    def set_feature_extractor_trainable(self, trainable):
        """Toggle ``requires_grad`` on every backbone parameter.

        Only gradients are affected: BatchNorm running statistics inside the
        backbone still update whenever the module is in train mode.
        """
        for param in self.effnet.parameters():
            param.requires_grad = trainable

In [8]:
# Build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself).
model = ModifiedEfficientNet().to(device)

In [9]:
# Training batches per epoch (displayed as the cell output).
num_train_batches = len(train_loader)
num_train_batches

27453

In [None]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
perf = CustomPerformancer()

# Training loop
num_epochs = 100


def _checkpoint(tag):
    """Upload ``effnet_{tag}.pth`` and rewrite both history CSVs under ``args.savepath``."""
    save_model(model.state_dict(), f"{args.savepath}/effnet_{tag}.pth")
    save_history(perf.histories['train'], f"{args.savepath}/train_history.csv")
    save_history(perf.histories['test'], f"{args.savepath}/test_history.csv")


def _run_epoch(loader, mode):
    """Run one pass over ``loader``; optimise the model only when ``mode == 'train'``.

    Accuracy is accumulated in ``perf`` and appended to ``perf.histories[mode]``.
    Autograd is disabled for any other mode, so evaluation builds no graphs.
    """
    training = mode == 'train'
    perf.reset()
    with torch.set_grad_enabled(training):
        for batch in loader:
            inputs, cho, jung, jong = (t.to(device) for t in batch)
            if training:
                optimizer.zero_grad()
            cho_out, jung_out, jong_out = model(inputs)
            if training:
                loss = (criterion(cho_out, cho)
                        + criterion(jung_out, jung)
                        + criterion(jong_out, jong))
                loss.backward()
                optimizer.step()
            perf.take(
                [cho_out.detach().cpu(), jung_out.detach().cpu(), jong_out.detach().cpu()],
                [cho.cpu(), jung.cpu(), jong.cpu()],
            )
            pbar.update(1)
            pbar.set_postfix(mode=mode, acc=f"{perf.accuracies()['total_accuracy']:.4f}")
    perf.save_history(mode)


pbar = tqdm(total=num_epochs * (len(train_loader) + len(test_loader)))
try:
    # Snapshot the untrained model under its own name so the post-epoch-0
    # checkpoint (effnet_0.pth) does not overwrite it; writing this early also
    # surfaces GCS errors before any training time is spent.
    _checkpoint('init')
    for epoch in range(num_epochs):
        model.train()
        # Epoch 0 trains only the three heads; later epochs fine-tune the backbone too.
        model.set_feature_extractor_trainable(epoch > 0)
        _run_epoch(train_loader, 'train')

        model.eval()
        _run_epoch(test_loader, 'test')

        _checkpoint(epoch)
finally:
    pbar.close()

  1%|          | 35040/2952000 [1:56:13<429:32:05,  1.89it/s, acc=0.9033, mode=train] 