# Feature Extraction

### 1. Feature Extraction
**Model to identify causal relationships**: 
a feature extraction network (ResNet18) trained on ImageNet, and 
a classifier network (two 512-unit hidden layers) trained on Pascal VOC 2012.

In [1]:
import numpy as np
import torch
from torch import nn
from torchvision.models import resnet18

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def feature_extractor():
    """Build a frozen, pretrained ResNet18 that returns features, not class scores.

    The network is created with ``pretrained=True``, every parameter has
    ``requires_grad`` switched off, and the final fully connected layer
    (``fc``) is replaced by ``nn.Flatten`` so the forward pass ends with the
    pooled backbone features.

    Returns:
        The modified ResNet18 module. It is returned in its default mode;
        callers doing inference should call ``.eval()`` on it themselves.
    """
    model_ft = resnet18(pretrained=True)
    # Freeze the backbone: it is used as a fixed feature extractor, so none
    # of its weights should receive gradients.
    for param in model_ft.parameters():
        param.requires_grad = False
    # Drop the classification head; ``fc`` now only flattens its input, so the
    # model output is the feature vector that used to feed the head.
    model_ft.fc = nn.Flatten()
    return model_ft

In [10]:
class Classifier(nn.Module):
    """Multi-layer perceptron head that maps feature vectors to category logits.

    Two hidden ``nn.Linear`` layers, each followed by a ReLU, feed a final
    linear layer with one output per category. Without the ReLUs the three
    stacked linear layers would be equivalent to a single linear map, so the
    hidden layers would add no modelling capacity.

    Args:
        in_features: Size of each input feature vector.
        hidden_dim: Width of both hidden layers.
        num_classes: Number of output logits (one per category).
    """

    def __init__(self, in_features=512, hidden_dim=512, num_classes=20):
        super().__init__()
        self.layer1 = nn.Linear(in_features, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) logits for a batch of feature vectors.

        Args:
            x: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_classes``. No softmax/sigmoid is applied here;
            apply the normalisation that matches the loss being used.
        """
        # The functional ReLU keeps the module's parameters (and therefore its
        # state_dict keys) exactly ``layer1``, ``layer2`` and ``classifier``.
        x = nn.functional.relu(self.layer1(x))
        x = nn.functional.relu(self.layer2(x))
        return self.classifier(x)

In [11]:
# Everything in this smoke test runs on the CPU.
device = "cpu"

# Build the feature backbone first, then the classification head, and move
# both to the selected device.
resnet = feature_extractor().to(device)
classifier = Classifier().to(device)

# A single random 3-channel 224x224 input (batch size 1) to check shapes with.
img = torch.rand(1, 3, 224, 224)

In [12]:
# Switch the backbone to evaluation mode before the forward pass.
resnet.eval()

# Backbone: image batch -> feature vectors.
features = resnet(img)
print(features.size())
# Head: feature vectors -> one logit per category.
object_logodds = classifier(features)
print(object_logodds.size())

torch.Size([1, 512])
torch.Size([1, 20])


### 2. ResNet Classifier Training

**Dataset**: 
subset of 99,309 MSCOCO images belonging to 20 Pascal object categories. 
Images are resized to 224x224: the shorter side is rescaled to 224 pixels, then a 224x224 center crop is taken.

In [None]:
# Names of the 20 object categories used by the classifier, in index order.
# NOTE(review): some spellings here (e.g. 'dining table', 'potted plant',
# 'television') may not match the label strings used by the dataset
# annotations - confirm before using these for lookups.
categories = [
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'dining table',
    'dog',
    'horse',
    'motorbike',
    'person',
    'potted plant',
    'sheep',
    'sofa',
    'train',
    'television',
]

In [None]:
class COCODataset:
    """Placeholder for an MSCOCO image dataset (bounding boxes + labels).

    Not implemented yet: the constructor takes no arguments and stores nothing.
    """

    def __init__(self):
        """Create an empty dataset object; no data is loaded."""