In [96]:
import os
import sys
from glob import glob
from PIL import Image

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import models
from torch.utils.data import DataLoader, Dataset

sys.path.insert(0, '../')
from dataset import TrainDataset
from transform_settings import configure_transform
from utils import load_json

In [105]:
# Build the training DataLoader from the preprocessed train split.
# TrainDataset / configure_transform come from the project's dataset.py and
# transform_settings.py (imported from '../' above).
ROOT = '../preprocessed/train'
TASK = 'all'  # presumably selects which labels TrainDataset returns; with 'all', later cells index labels by 'mask', 'ageg', 'gender' — TODO confirm in dataset.py
META = '../preprocessed/metadata.json'
transform = configure_transform('train', 'base')
data = TrainDataset(ROOT, transform, TASK, META)
# Small batches for interactive shape/loss checks; DataLoader does not shuffle by default.
loader = DataLoader(data, batch_size=4)

In [108]:
# Take a single batch for the shape / loss checks below. Unlike `for ...: break`,
# this raises StopIteration on an empty loader instead of silently leaving
# `imgs` / `labels` undefined.
imgs, labels = next(iter(loader))

In [118]:
# Stand-alone frozen ResNet-18, the same pretrained backbone THAResNet builds
# below. `backbone` must be defined in this cell, otherwise the loop raises
# NameError on a fresh kernel.
backbone = models.resnet18(pretrained=True)
for param in backbone.parameters():
    param.requires_grad = False

In [123]:
from collections import OrderedDict

class THAResNet(nn.Module):  # Three-headed ResNet-18: mask / age-group / gender
    """Multi-task classifier: one shared ResNet-18 backbone and three linear heads.

    The backbone is torchvision's ImageNet-pretrained ResNet-18 with its
    original ``fc`` layer kept, so every head consumes the backbone's
    ``fc.out_features``-dim (1000 for ResNet-18) ImageNet logits.

    NOTE(review): heads sit on top of ImageNet class logits rather than pooled
    features. Replacing ``backbone.fc`` with ``nn.Identity()`` and sizing the
    heads by ``fc.in_features`` is the more usual setup — consider it.

    Args:
        num_mask: number of mask classes (default 3).
        num_ageg: number of age-group classes (default 3).
        num_gender: number of gender classes (default 2).
        freeze: if True (default), backbone parameters get
            ``requires_grad=False`` and the backbone is kept in eval mode, so
            its BatchNorm running statistics are not updated either.
    """

    def __init__(self, num_mask=3, num_ageg=3, num_gender=2, freeze=True):
        super().__init__()
        self.backbone = models.resnet18(pretrained=True)
        self.freeze_backbone = freeze

        # Read the head input width from the backbone instead of hard-coding 1000.
        in_features = self.backbone.fc.out_features
        self.linear_mask = nn.Linear(in_features, num_mask)
        self.linear_ageg = nn.Linear(in_features, num_ageg)
        self.linear_gender = nn.Linear(in_features, num_gender)

        if freeze:
            self._freeze()

    def forward(self, x):
        """Return ``(mask_logits, ageg_logits, gender_logits)`` for the image batch ``x``."""
        features = self.backbone(x)
        return (
            self.linear_mask(features),
            self.linear_ageg(features),
            self.linear_gender(features),
        )

    def train(self, mode=True):
        """Set train/eval mode on the heads, keeping a frozen backbone in eval mode.

        ``requires_grad=False`` alone does not stop BatchNorm layers from
        updating their running mean/variance while in training mode.
        """
        super().train(mode)
        if getattr(self, 'freeze_backbone', False):
            self.backbone.eval()
        return self

    def _freeze(self):
        """Disable gradients for all backbone parameters and put the backbone in eval mode."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

In [124]:
# Instantiate the three-head model; __init__ loads pretrained ResNet-18 weights
# (pretrained=True) and freezes the backbone's parameters.
model = THAResNet()

In [125]:
# One Adam optimizer per head; the frozen backbone is not optimized. All heads
# share a single learning rate, named here instead of repeating the literal.
HEAD_LR = .005
optim_mask = optim.Adam(model.linear_mask.parameters(), lr=HEAD_LR)
optim_ageg = optim.Adam(model.linear_ageg.parameters(), lr=HEAD_LR)
optim_gender = optim.Adam(model.linear_gender.parameters(), lr=HEAD_LR)

# Shared loss for all three classification heads.
criterion = nn.CrossEntropyLoss()

In [126]:
# Forward pass on the sample batch; THAResNet.forward returns one logits tensor per head.
output_mask, output_ageg, output_gender = model(imgs)

In [127]:
# One cross-entropy term per head, each scored against its own label field.
loss_mask, loss_ageg, loss_gender = (
    criterion(logits, labels[key])
    for logits, key in (
        (output_mask, 'mask'),
        (output_ageg, 'ageg'),
        (output_gender, 'gender'),
    )
)

In [117]:
# Halve each head's loss. Each term is recomputed from the logits instead of
# divided in place with `/=`, so re-running this cell yields the same values
# rather than halving the previous result again.
loss_mask = criterion(output_mask, labels['mask']) / 2
loss_ageg = criterion(output_ageg, labels['ageg']) / 2
loss_gender = criterion(output_gender, labels['gender']) / 2

tensor(0.5958, grad_fn=<DivBackward0>)

In [6]:
# Replace the final classifier with Identity so resnet(x) returns the pooled
# backbone features instead of class logits.
# NOTE(review): `resnet` is not defined anywhere in the visible notebook, so this
# cell fails on a fresh kernel. The 2048-dim features and `conv3` parameter names
# in the outputs below suggest a bottleneck torchvision ResNet (e.g. resnet50) —
# TODO restore the cell that builds it.
resnet.fc = nn.Identity()

In [8]:
# Take a single batch for the feature / loss checks below; raises StopIteration
# on an empty loader instead of silently leaving `img` / `label` undefined.
# NOTE(review): `dataloader` is not defined in the visible notebook (the loader
# built earlier is `loader` with batch_size=4, while the output below shows 16
# rows) — TODO confirm which loader is intended.
img, label = next(iter(dataloader))

In [21]:
# Backbone features for the sample batch; shown as (batch, feature_dim).
post_resnet = resnet(img)
post_resnet.shape

torch.Size([16, 2048])

In [24]:
# Number of target classes for each classification head.
MASK_CLASSES = 3
AGEG_CLASSES = 3  # presumably age groups
SEX_CLASSES = 2

In [64]:
from collections import OrderedDict
def get_linear_dropout(in_features, out_features, p=.5):
    """Build a ``Linear -> Dropout`` head as a named ``nn.Sequential``.

    Args:
        in_features: size of each input feature vector.
        out_features: number of outputs (e.g. classes) of the linear layer.
        p: dropout probability (default 0.5).

    Returns:
        ``nn.Sequential`` with sub-modules ``linear`` then ``dropout``.
    """
    # NOTE(review): dropout is applied to the linear layer's *outputs* (the
    # logits when used as a classifier head). Dropout before the linear layer
    # is the more common arrangement — confirm this ordering is intended.
    return nn.Sequential(OrderedDict([
        ('linear', nn.Linear(in_features=in_features, out_features=out_features)),
        ('dropout', nn.Dropout(p=p)),  # honour `p` instead of a hard-coded 0.5
    ]))

In [65]:
# One Linear+Dropout head per task on top of the backbone features. Class counts
# come from the constants defined above instead of repeated literals.
RESNET_FEATURES = 2048  # feature width printed by post_resnet.size() above
fc_mask = get_linear_dropout(in_features=RESNET_FEATURES, out_features=MASK_CLASSES)
fc_sex = get_linear_dropout(in_features=RESNET_FEATURES, out_features=SEX_CLASSES)
fc_ageg = get_linear_dropout(in_features=RESNET_FEATURES, out_features=AGEG_CLASSES)

In [67]:
# Register the three heads together. Order is (mask, sex, ageg), so outputs[i]
# below follows this order. This differs from THAResNet's (mask, ageg, gender).
fc_layer = nn.ModuleList([fc_mask, fc_sex, fc_ageg])

In [68]:
# Run the shared backbone features through every head; order follows fc_layer
# (mask, sex, ageg).
outputs = [head(post_resnet) for head in fc_layer]
    

In [53]:
# Dummy all-zeros class targets, one per row of the mask-head logits. Sizing from
# the logits keeps it in sync with the batch instead of a hard-coded 16;
# torch.zeros(dtype=long) replaces the legacy torch.LongTensor constructor.
label = torch.zeros(outputs[0].size(0), dtype=torch.long)

In [54]:
# Sanity-check the mask head's loss against the dummy targets.
F.cross_entropy(input=outputs[0], target=label)

tensor(0.9253, grad_fn=<NllLossBackward>)

In [None]:
# NOTE(review): scratch expression — it only displays an empty ModuleList and
# binds nothing; looks safe to remove.
nn.ModuleList()

In [8]:
# List the backbone's parameter names (e.g. to choose which layers to freeze).
# The parameter tensors themselves are not needed here.
for name, _ in resnet.named_parameters():
    print(name)

conv1.weight
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.conv3.weight
layer1.0.bn3.weight
layer1.0.bn3.bias
layer1.0.downsample.0.weight
layer1.0.downsample.1.weight
layer1.0.downsample.1.bias
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.conv3.weight
layer1.1.bn3.weight
layer1.1.bn3.bias
layer1.2.conv1.weight
layer1.2.bn1.weight
layer1.2.bn1.bias
layer1.2.conv2.weight
layer1.2.bn2.weight
layer1.2.bn2.bias
layer1.2.conv3.weight
layer1.2.bn3.weight
layer1.2.bn3.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.conv3.weight
layer2.0.bn3.weight
layer2.0.bn3.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.we