In [1]:
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import csv
import tqdm.notebook as tq
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision.utils import save_image

import matplotlib.pyplot as plt

In [2]:
H, W = 300, 500

In [3]:
#  !ls -la "/content/data/"

In [4]:
# transform = torchvision.transforms.Resize((H,W))
transform = transforms.Compose([
    transforms.Resize((H,W)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
train = torchvision.datasets.ImageFolder("data", transform)
# train = torchvision.datasets.ImageFolder("data", transform)

In [5]:
train_loader = DataLoader(
    train,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )

In [6]:
# get waterbirds; returns dict birds
# 1,001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg,1,2,1,/o/ocean/00002178.jpg
# img_id,img_filename,y,split,place,place_filename
def get_place():
    '''
    Gets a dictionary mapping file names to environment ID

    Returns:    
        dict with values (string) filename mapped to
        keys (int) for environment ID*

        *1 corresponds to water environment (ocean) and
        *0 corresponds to land environment

    Example entry:
        # print(get_place())
        # '001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg': 1, 
        # '001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg': 1,	
    '''
    birds = {}
    with open('./data/metadata.csv') as f:
        reader = csv.DictReader(f)
        for row in reader:
            fname = row['img_filename'].split('/')[-1]
            birds[fname] = int(row['place'])
    return birds


In [7]:
birds = get_place()

In [8]:
class BasicConvNet(nn.Module):
    '''
    A basic neural network. Not particularly based on anything,
    2 conv layers and last layer is FC.
    '''

    def __init__(self, num_classes, num_envs=2):
        super(BasicConvNet, self).__init__()
        self.conv2d1 = nn.Conv2d(3, 2, 3, 1, 1)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2d2 = nn.Conv2d(2, 4, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(12544, num_classes * num_envs)
        self._weight_init()

    def _weight_init(self):
        for layer in self.children():
            if (layer.__class__.__name__ == 'Linear' \
                or layer.__class__.__name__ == 'Conv2d'):
                nn.init.kaiming_uniform_(layer.weight.data,0.2)
                nn.init.constant_(layer.bias.data,0)
            

    def forward(self, x):
        x = self.conv2d1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2d2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

## Phase 1
Train featurizers and classifiers together for all environments.
Both featurizer and classifier should be mutable (are being trained).

Featurizer should be same for samples from all environments,
classifier should be different for samples from different environments.

In [9]:
lr = 0.01
n_epochs = 50
birds = get_place()
net = BasicConvNet(num_classes=2)
optimizer = optim.Adam(net.parameters(), lr=lr)
optimizer.zero_grad()
criterion = nn.CrossEntropyLoss()

# need 2 different loss functions. 
# such that they have gradient = 0 on half the weights
# for the irrelevant environment

for epoch in tq.tqdm(range(n_epochs)):
    for batch_idx, data in  enumerate(train_loader):
        inputs, labels = data

        # inputs are currently sorted
        sample_fname, _ = train_loader.dataset.samples[batch_idx]
        fname = sample_fname.split('/')[-1]
        # filename example: Black_Footed_Albatross_0009_34.jpg

        if birds[fname] == 0:
            # Train land-env birds
            optimizer.zero_grad()
            outputs = net(inputs)[:,:2] # get the first num_classes entries      
            land_loss = criterion(outputs, labels.type(torch.LongTensor)) 
            land_loss.backward()
            optimizer.step()
            
        elif birds[fname] == 1:
            # Train water-env birds
            optimizer.zero_grad()
            outputs = net(inputs)[:,2:] # get the last num_class entries
            water_loss = criterion(outputs, labels.type(torch.LongTensor)) 
            water_loss.backward()
            optimizer.step()


    print(f"E: {epoch}; Land Loss: {land_loss.item()}; Water Loss: {water_loss.item()}", birds[fname])


  0%|          | 0/50 [00:00<?, ?it/s]

E: 0; Land Loss: 0.7006399631500244; Water Loss: 0.613723874092102 1
E: 1; Land Loss: 0.7138070464134216; Water Loss: 0.6097952127456665 1
E: 2; Land Loss: 0.7397164106369019; Water Loss: 0.6049914956092834 1
E: 3; Land Loss: 0.7641109228134155; Water Loss: 0.6022160053253174 1
E: 4; Land Loss: 0.7856600284576416; Water Loss: 0.600159227848053 1
E: 5; Land Loss: 0.8044002056121826; Water Loss: 0.5986003279685974 1
E: 6; Land Loss: 0.820556104183197; Water Loss: 0.5974159240722656 1
E: 7; Land Loss: 0.8344010710716248; Water Loss: 0.596515953540802 1
E: 8; Land Loss: 0.8462138772010803; Water Loss: 0.5958320498466492 1
E: 9; Land Loss: 0.8562589883804321; Water Loss: 0.595312774181366 1
E: 10; Land Loss: 0.8647788763046265; Water Loss: 0.5949188470840454 1
E: 11; Land Loss: 0.8719899654388428; Water Loss: 0.5946202278137207 1
E: 12; Land Loss: 0.8780828714370728; Water Loss: 0.5943942666053772 1
E: 13; Land Loss: 0.8832239508628845; Water Loss: 0.5942235589027405 1
E: 14; Land Loss: 0.8

In [10]:
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

# # Print optimizer's state_dict
# print("Optimizer's state_dict:")
# for var_name in optimizer.state_dict():
#     print(var_name, "\t", optimizer.state_dict()[var_name])

torch.save(net.state_dict(), "phase1")

Model's state_dict:
conv2d1.weight 	 torch.Size([2, 3, 3, 3])
conv2d1.bias 	 torch.Size([2])
conv2d2.weight 	 torch.Size([4, 2, 3, 3])
conv2d2.bias 	 torch.Size([4])
fc.weight 	 torch.Size([4, 12544])
fc.bias 	 torch.Size([4])


## Phase 2

Train classifiers for environments only, freezing featurizers

In [11]:
for layer in net.children():
    # unless it's the last linear layer, freeze everything.
    for p in layer.parameters():
        p.requires_grad = False
    if layer.__class__.__name__ == 'Linear':
        # if it's one of the first environment weights, let it train
        for i,p in enumerate(layer.parameters()):
            if i < 2: p.requires_grad=True

for layer in net.children():
    print(layer.__class__.__name__)
    for p in layer.parameters():
        print("is frozen: ", str(not p.requires_grad))

Conv2d
is frozen:  True
is frozen:  True
ReLU
MaxPool2d
Conv2d
is frozen:  True
is frozen:  True
MaxPool2d
Flatten
Linear
is frozen:  False
is frozen:  False


In [12]:
l_optim_frozen = optim.SGD(
    filter(lambda p: p.requires_grad, net.parameters()), lr=lr)

for epoch in tq.tqdm(range(n_epochs)):
    for batch_idx, data in enumerate(train_loader):
        inputs, labels = data
        l_optim_frozen.zero_grad()
        outputs = net(inputs)[:,:2] 
        # we're only working with the first weights now 
        loss = criterion(outputs, labels.type(torch.LongTensor)) 
        loss.backward()
        l_optim_frozen.step()


    print(f"E: {epoch}; Land Loss: {land_loss.item()}; Water Loss: {water_loss.item()}")

  0%|          | 0/50 [00:00<?, ?it/s]

E: 0; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 1; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 2; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 3; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 4; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 5; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 6; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 7; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 8; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 9; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 10; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 11; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 12; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 13; Land Loss: 0.9101296067237854; Water Loss: 0.5935475826263428
E: 14; Land Loss: 0.9101296067237854; Water 