In [24]:
import os
import sys
from glob import glob
from PIL import Image

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import models
from torch.utils.data import DataLoader, Dataset

sys.path.insert(0, '../')
from dataset import TrainDataset
from augmentation import configure_transform
from utils import load_json, set_seed
from model import VanillaEfficientNet

In [30]:
SAVE_PATH = '../saved_models'

# File names (joined with SAVE_PATH when loaded) of the three single-task
# VanillaEfficientNet checkpoints. The names appear to encode the training
# run's settings and scores.
load_gender, load_mask, load_ageg = (
    'VanillaEfficientNet_taskgender_epoch02_lr0.005_transformbase_optimadam_loss0.0001_eval0.9658_seed42.pth',
    'VanillaEfficientNet_taskmask_epoch04_lr0.005_transformbase_optimadam_loss0.0000_eval0.9909_seed42.pth',
    'VanillaEfficientNet_taskageg_epoch04_lr0.005_transformbase_optimadam_loss0.0002_eval0.9118_seed42.pth',
)

In [35]:
def load_linear_head(filename, n_class):
    """Return the ``.linear`` submodule of a trained VanillaEfficientNet.

    Builds a ``VanillaEfficientNet(n_class=n_class)``, loads the state dict
    stored at ``SAVE_PATH/filename`` into it, and keeps only ``.linear``.

    Args:
        filename: checkpoint file name inside SAVE_PATH.
        n_class: number of outputs the checkpointed model was trained with.

    Returns:
        The model's ``.linear`` module with the loaded weights.
    """
    net = VanillaEfficientNet(n_class=n_class)
    # map_location='cpu' lets checkpoints written from a GPU run load on a
    # CPU-only machine. torch.load unpickles the file: use trusted files only.
    state_dict = torch.load(os.path.join(SAVE_PATH, filename), map_location='cpu')
    net.load_state_dict(state_dict)
    return net.linear


# NOTE(review): none of the visible cells below use these heads;
# THAResNet_MK1 creates fresh nn.Linear(1000, ...) heads. Their input size may
# not match ResNet-18's 1000-d output -- TODO confirm before reusing them.
linear_gender = load_linear_head(load_gender, n_class=2)
linear_mask = load_linear_head(load_mask, n_class=3)
linear_ageg = load_linear_head(load_ageg, n_class=3)

Loaded pretrained weights for efficientnet-b3
Loaded pretrained weights for efficientnet-b3
Loaded pretrained weights for efficientnet-b3


In [2]:
# Training data: image root, task selector and metadata file.
ROOT, TASK, META = '../preprocessed/train', 'all', '../preprocessed/metadata.json'

# Seed first, so anything random in the transform/dataset setup is reproducible.
set_seed(42)
transform = configure_transform('train', 'base')
data = TrainDataset(ROOT, transform, TASK, META)
# shuffle is left at DataLoader's default (False).
loader = DataLoader(dataset=data, batch_size=4)

Seed set as 42


In [3]:
# Take a single batch to smoke-test the training step below. Unlike a
# `for ... break` loop, next(iter(...)) raises StopIteration on an empty loader
# instead of leaving `imgs`/`labels` unbound (or stale from an earlier run).
imgs, labels = next(iter(loader))

In [4]:
from collections import OrderedDict

class THAResNet_MK1(nn.Module):
    """Multi-task classifier: one ResNet-18 backbone shared by four linear heads.

    The backbone's output vector (the logits of torchvision ResNet-18's final
    ``fc`` layer) feeds every head directly; there is no attention module.
    Each head returns raw logits:

    * ``linear_mask``   -> ``n_mask`` classes
    * ``linear_ageg``   -> ``n_ageg`` classes
    * ``linear_gender`` -> ``n_gender`` classes
    * ``linear_main``   -> ``n_main`` classes

    Args:
        n_mask: output size of the mask head (default 3).
        n_ageg: output size of the ``ageg`` head (default 3).
        n_gender: output size of the gender head (default 2).
        n_main: output size of the main head (default 18).
        pretrained: forwarded to ``models.resnet18``; True (the default)
            loads torchvision's ImageNet weights.
    """

    def __init__(self, n_mask=3, n_ageg=3, n_gender=2, n_main=18, pretrained=True):
        super().__init__()
        self.backbone = models.resnet18(pretrained=pretrained)
        # Width of the backbone output (1000 for the stock ResNet-18), read
        # from the model instead of hard-coding it.
        in_features = self.backbone.fc.out_features
        # Registration order is kept as before so state_dict keys and
        # parameters() order are unchanged.
        self.linear_mask = nn.Linear(in_features, n_mask)
        self.linear_ageg = nn.Linear(in_features, n_ageg)
        self.linear_gender = nn.Linear(in_features, n_gender)
        self.linear_main = nn.Linear(in_features, n_main)

    def forward(self, x):
        """Return ``(mask, ageg, gender, main)`` logits for the batch ``x``.

        ``x`` is whatever ``models.resnet18`` accepts -- presumably an image
        batch shaped (B, 3, H, W); TODO confirm against the dataset transform.
        """
        features = self.backbone(x)
        return (
            self.linear_mask(features),
            self.linear_ageg(features),
            self.linear_gender(features),
            self.linear_main(features),
        )

    def _freeze(self):
        """Turn off gradients for every backbone parameter; heads stay trainable."""
        for param in self.backbone.parameters():
            param.requires_grad = False

In [5]:
# Instantiate the multi-task model; THAResNet_MK1 builds its ResNet-18
# backbone with pretrained=True.
model = THAResNet_MK1()

In [17]:
optim_mask = optim.Adam(model.linear_mask.parameters(), lr=.005)
optim_gender = optim.Adam(model.linear_gender.parameters(), lr=.005)
optim_ageg = optim.Adam(model.linear_ageg.parameters(), lr=.005)
optim_main = optim.Adam(model.parameters(), lr=.005)

optim_mask_interaction = optim.Adam(list(model.backbone.parameters()) + list(model.linear_mask.parameters()), lr=.005)
optim_ageg_interaction = optim.Adam(list(model.backbone.parameters()) + list(model.linear_ageg.parameters()), lr=.005)
optim_gender_interaction = optim.Adam(list(model.backbone.parameters()) + list(model.linear_gender.parameters()), lr=.005)

criterion = nn.CrossEntropyLoss()

In [18]:
output_mask, output_ageg, output_gender, output_main = model(imgs)

In [19]:
loss_mask = criterion(output_mask, labels['mask'])
loss_ageg = criterion(output_ageg, labels['ageg'])
loss_gender = criterion(output_gender, labels['gender'])
loss_main = criterion(output_main, labels['main'])

loss_mask *= 0.375
loss_ageg *= 0.375
loss_gender *= 0.375
loss_main *= .5

loss_mask_interaction = criterion(output_mask, labels['mask'])
loss_ageg_interaction = criterion(output_ageg, labels['ageg'])
loss_gender_interaction = criterion(output_gender, labels['gender'])

loss_mask_interaction *= 0.125
loss_ageg_interaction *= 0.125
loss_gender_interaction *= 0.125

In [20]:
loss_mask_interaction.backward(retain_graph=True)
loss_ageg_interaction.backward(retain_graph=True)
loss_gender_interaction.backward(retain_graph=True)

loss_mask.backward(retain_graph=True)
loss_ageg.backward(retain_graph=True)
loss_gender.backward(retain_graph=True)
loss_main.backward()

In [21]:
optim_mask.step()
optim_gender.step()
optim_ageg.step()
optim_main.step()
optim_mask_interaction.step()
optim_ageg_interaction.step()
optim_gender_interaction.step()