In [194]:
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import csv
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision.utils import save_image

import matplotlib.pyplot as plt

In [195]:
H, W = 300, 500

In [196]:
# transform = torchvision.transforms.Resize((H,W))
transform = transforms.Compose([
    transforms.Resize((H,W)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
train = torchvision.datasets.ImageFolder("data", transform)
# train = torchvision.datasets.ImageFolder("data", transform)

In [197]:
train_loader = DataLoader(
    train,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )

In [198]:
# get waterbirds; returns dict birds
# 1,001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg,1,2,1,/o/ocean/00002178.jpg
# img_id,img_filename,y,split,place,place_filename
def get_place():
    '''
    returns:    birds, dict with keys 
                    '0' for land
                    '1' for water
    '''
    birds = {}
    with open('./waterbirds/metadata.csv') as f:
        reader = csv.DictReader(f)
        for row in reader:
            fname = row['img_filename'].split('/')[-1]
            birds[fname] = int(row['place'])
    return birds
			
# print(get_place())
# '001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg': 1, 
# '001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg': 1,	


In [199]:
birds = get_place()

In [200]:
for batch_idx, (images, labels) in enumerate(train_loader):
      # print("batch", batch_idx)
      # for i, (images, labels) in enumerate(test_loader, 0):
      # outputs = model(images)
      # _, predicted = torch.max(outputs.data, 1)
      fname, _ = train_loader.dataset.samples[batch_idx]
      # print(sample_fname)
      fname = fname.split('/')[-1]
      # print(fname)
      # TODO: handle batch size > 1

In [201]:
class BasicConvNet(nn.Module):
    """This is a small neural network."""

    def __init__(self, num_classes):
        super(BasicConvNet, self).__init__()
        self.conv2d1 = nn.Conv2d(3, 2, 3, 1, 1)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2d2 = nn.Conv2d(2, 4, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(12544, num_classes)
        self._weight_init()

    def _weight_init(self):
        for layer in self.children():
            if (layer.__class__.__name__ == 'Linear' \
                or layer.__class__.__name__ == 'Conv2d'):
                nn.init.kaiming_uniform_(layer.weight.data,0.2)
                nn.init.constant_(layer.bias.data,0)
            

    def forward(self, x):
        x = self.conv2d1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2d2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        # x = x.view(-1,1)
        # x = x.view(-1)
        return x

## Train featurizers and classifiers together for all environments

In [213]:
lr = 0.01
n_epochs = 10
birds = get_place()

land_net = BasicConvNet(num_classes=2)
water_net = BasicConvNet(num_classes=2)
land_optimizer = optim.Adam(land_net.parameters(), lr=lr)
water_optimizer = optim.Adam(water_net.parameters(), lr=lr)
water_optimizer.zero_grad()
land_optimizer.zero_grad()
criterion = nn.CrossEntropyLoss()

for epoch in tqdm(range(n_epochs)):
    for batch_idx, data in tqdm(
        enumerate(train_loader), total=len(train_loader)
    ):
        inputs, labels = data

        sample_fname, _ = train_loader.dataset.samples[batch_idx]
        fname = sample_fname.split('/')[-1]

        if birds[fname] == 0:
            # Train land-env birds
            land_optimizer.zero_grad()
            outputs = land_net(inputs)
            land_loss = criterion(outputs, labels.type(torch.LongTensor)) 
            land_loss.backward()
            land_optimizer.step()
            
        elif birds[fname] == 1:
            # Train water-env birds
            water_optimizer.zero_grad()
            outputs = water_net(inputs)
            water_loss = criterion(outputs, labels.type(torch.LongTensor)) 
            water_loss.backward()
            water_optimizer.step()


    print(f"E: {epoch}; Land Loss: {land_loss.item()}; Water Loss: {water_loss.item()}")


100%|██████████| 30/30 [00:01<00:00, 16.62it/s]
 10%|█         | 1/10 [00:01<00:16,  1.81s/it]

E: 0; Land Loss: 8.124662399291992; Water Loss: 0.07288221269845963


100%|██████████| 30/30 [00:01<00:00, 17.37it/s]
 20%|██        | 2/10 [00:03<00:14,  1.77s/it]

E: 1; Land Loss: 0.6634581089019775; Water Loss: 0.6668913960456848


100%|██████████| 30/30 [00:01<00:00, 17.34it/s]
 30%|███       | 3/10 [00:05<00:12,  1.75s/it]

E: 2; Land Loss: 0.689561665058136; Water Loss: 0.6551748514175415


100%|██████████| 30/30 [00:01<00:00, 18.51it/s]
 40%|████      | 4/10 [00:06<00:10,  1.71s/it]

E: 3; Land Loss: 0.7043549418449402; Water Loss: 0.6436312198638916


100%|██████████| 30/30 [00:01<00:00, 18.83it/s]
 50%|█████     | 5/10 [00:08<00:08,  1.67s/it]

E: 4; Land Loss: 0.7180204391479492; Water Loss: 0.6340154409408569


100%|██████████| 30/30 [00:01<00:00, 19.30it/s]
 60%|██████    | 6/10 [00:10<00:06,  1.63s/it]

E: 5; Land Loss: 0.7258651256561279; Water Loss: 0.6262399554252625


100%|██████████| 30/30 [00:01<00:00, 18.69it/s]
 70%|███████   | 7/10 [00:11<00:04,  1.63s/it]

E: 6; Land Loss: 0.7230496406555176; Water Loss: 0.6199921369552612


100%|██████████| 30/30 [00:02<00:00, 14.80it/s]
 80%|████████  | 8/10 [00:13<00:03,  1.76s/it]

E: 7; Land Loss: 0.6764634251594543; Water Loss: 0.6149795651435852


100%|██████████| 30/30 [00:01<00:00, 18.09it/s]
 90%|█████████ | 9/10 [00:15<00:01,  1.73s/it]

E: 8; Land Loss: 0.45330893993377686; Water Loss: 0.6109597086906433


100%|██████████| 30/30 [00:01<00:00, 20.28it/s]
100%|██████████| 10/10 [00:16<00:00,  1.69s/it]

E: 9; Land Loss: 0.35834068059921265; Water Loss: 0.6077364087104797





## Train classifiers for environments only, freezing featurizers

In [212]:
for layer in land_net.children():
    if layer.__class__.__name__ == 'Linear':
        for p in layer.parameters():
            p.requires_grad = False

for param in water_net.parameters():
    if layer.__class__.__name__ == 'Linear':
        for p in layer.parameters():
            p.requires_grad = False

l_optim_frozen = optim.SGD(filter(lambda p: p.requires_grad, land_net.parameters()), lr=lr)
w_optim_frozen = optim.SGD(filter(lambda p: p.requires_grad, water_net.parameters()), lr=lr)