<a href="https://colab.research.google.com/github/claireluo66/birds_dataset/blob/main/SSL_birds_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.io import read_image, ImageReadMode
import torch.optim as optim

from collections import namedtuple
from operator import itemgetter
from itertools import groupby
from pathlib import Path
from tqdm.notebook import tqdm

## Data

In [2]:
!gdown --id 10X4KdsEsqhuZBID-iMAaZMmK6BYZXIVI
!tar -xf CUB_200_2011.tgz

Downloading...
From: https://drive.google.com/uc?id=10X4KdsEsqhuZBID-iMAaZMmK6BYZXIVI
To: /content/CUB_200_2011.tgz
1.15GB [00:10, 107MB/s]


In [3]:
# Upload the train/test/val split file that the CUB datasets below read
# (default path "train_test_val_split.txt").
# files.upload() opens an interactive picker, which blocks "Run all"; skip it when
# the file is already present. `uploaded` stays a dict in both branches.
from google.colab import files

SPLIT_FILE = Path("train_test_val_split.txt")
if SPLIT_FILE.exists():
    uploaded = {}
else:
    uploaded = files.upload()

Saving train_test_val_split.txt to train_test_val_split.txt


In [4]:
# Mount Google Drive at /content/gdrive.
# NOTE(review): no cell visible in this notebook reads from /content/gdrive — confirm
# the mount is still needed (e.g. for saving checkpoints) or drop it.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
def parse_datatype(x):
    """Convert one whitespace-separated token to int, else float, else leave it as str.

    Args:
        x: a single token (str) from one of the CUB metadata text files.

    Returns:
        ``int(x)`` if that succeeds, otherwise ``float(x)`` if that succeeds,
        otherwise ``x`` unchanged.
    """
    for datatype in (int, float):
        try:
            return datatype(x)
        except ValueError:
            pass
    return x


def parse_and_filter(x, ids_load=None):
    """Yield each non-blank line of ``x`` as a list of parsed tokens.

    Args:
        x: an iterable of text lines (e.g. an open file).
        ids_load: optional container of integer ids; when given, only lines whose
            first token (parsed with ``int``) is in ``ids_load`` are yielded.

    Yields:
        ``[parse_datatype(tok) for tok in line.split()]`` for every kept line.
    """
    for line in x:
        tokens = line.split()
        # Blank lines (e.g. a stray empty line in a hand-edited split file) have no
        # id column; skip them instead of raising IndexError on tokens[0].
        if not tokens:
            continue
        if ids_load is None or int(tokens[0]) in ids_load:
            yield [parse_datatype(tok) for tok in tokens]

In [6]:
class CUB(Dataset):
    """One split of CUB-200-2011, loaded from the dataset's metadata text files.

    Each item is a dict containing whichever of these were enabled at construction:
      - "image": ``read_image(..., ImageReadMode.RGB) / 255`` passed through
        ``transform`` if one was given;
      - "label": class id from ``image_class_labels.txt`` minus one;
      - "attributes": float tensor of the "present" column of
        ``attributes/image_attribute_labels.txt`` for that image.
    """

    def __init__(
        self,
        dataset_path,
        split_file_path="train_test_val_split.txt",
        *,
        type=0,
        images=True,
        labels=True,
        attributes=True,
        transform=None,
        show_progress=True,
    ):
        """Load metadata for every image whose split code equals ``type``.

        Args:
            dataset_path: root directory of the extracted CUB_200_2011 archive.
            split_file_path: text file of ``<image_id> <split_code>`` lines.
            type: split code to keep (this notebook uses 0=train, 1=test, 2=val).
            images: load image paths and return "image" from ``__getitem__``.
            labels: load class labels and return "label".
            attributes: load per-image attribute vectors and return "attributes".
            transform: optional callable applied to the image tensor.
            show_progress: wrap each metadata pass in a tqdm progress bar.
        """
        root = Path(dataset_path)
        self.transform = transform

        with open(split_file_path) as f:
            self.ids = [img_id for img_id, split in parse_and_filter(f) if split == type]
        ids = set(self.ids)

        def progress(iterable, desc):
            # Each metadata pass below yields one item per kept image id.
            if show_progress:
                return tqdm(iterable, total=len(ids), desc=desc)
            return iterable

        if images:
            with open(root / "images.txt") as f:
                self.image_paths = {
                    img_id: root / "images" / name
                    for img_id, name in progress(parse_and_filter(f, ids), "images")
                }

        if labels:
            with open(root / "image_class_labels.txt") as f:
                # Shift ids down by one (presumably 1-based in the file) for CrossEntropyLoss.
                self.labels = {
                    img_id: label - 1
                    for img_id, label in progress(parse_and_filter(f, ids), "labels")
                }

        if attributes:

            def drop_extra_field(rows):
                # Some rows carry more than five fields; drop the one at index 4 so a
                # six-field row unpacks into the five values expected below.
                for row in rows:
                    yield row[:4] + row[5:] if len(row) > 5 else row

            # groupby only merges *adjacent* rows, so this assumes every row of one image
            # is contiguous in the file. NOTE(review): confirm for this file.
            with open(root / "attributes" / "image_attribute_labels.txt") as f:
                grouped = groupby(parse_and_filter(f, ids), key=itemgetter(0))
                self.attributes = {
                    img_id: torch.tensor(
                        [float(present) for _, _, present, _, _ in drop_extra_field(rows)]
                    )
                    for img_id, rows in progress(grouped, "attributes")
                }

    def __len__(self):
        """Number of images in this split (one per id read from the split file)."""
        # Use self.ids, not self.image_paths: __getitem__ indexes self.ids, and
        # image_paths does not exist at all when images=False.
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the dict of loaded fields for the ``idx``-th image id of this split."""
        img_id = self.ids[idx]
        item = {}

        if hasattr(self, "image_paths"):
            # Scale the 8-bit pixels to [0, 1] floats before the transforms run.
            image = read_image(str(self.image_paths[img_id]), mode=ImageReadMode.RGB) / 255
            if self.transform:
                image = self.transform(image)
            item["image"] = image
        if hasattr(self, "labels"):
            item["label"] = self.labels[img_id]
        if hasattr(self, "attributes"):
            item["attributes"] = self.attributes[img_id]
        if hasattr(self, "attribute_classes"):
            item["attribute_classes"] = self.attribute_classes[img_id]

        return item

In [7]:
# Image pipelines for the [0, 1] float tensors yielded by the dataset. Both pipelines
# finish with the same per-channel Normalize (ImageNet mean / std values).
preprocess = dict(
    # Training: random rotation, random-resized crop and horizontal flip as augmentation.
    train=transforms.Compose(
        [
            transforms.Resize(256),
            transforms.RandomRotation(45),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(
                [0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225],
            ),
        ]
    ),
    # Evaluation (used for both test and val): deterministic resize + center crop.
    test=transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.Normalize(
                [0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225],
            ),
        ]
    ),
)

In [8]:
# Build the three splits in train, test, val order. The integer is the split code
# matched against the second column of the split file; val reuses the test pipeline.
train_dataset, test_dataset, val_dataset = (
    CUB("CUB_200_2011", type=split_code, transform=preprocess[pipeline])
    for split_code, pipeline in ((0, "train"), (1, "test"), (2, "test"))
)

HBox(children=(FloatProgress(value=0.0, description='images', max=8232.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='labels', max=8232.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='attributes', max=8232.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='images', max=1773.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='labels', max=1773.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='attributes', max=1773.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='images', max=1783.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='labels', max=1783.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='attributes', max=1783.0, style=ProgressStyle(description_…




## Multimodal Model

In [9]:
class SSLModel(nn.Module):
    """Image + attribute model with a self-supervised attribute-prediction pretext task.

    - ``pretrain(images)``: ResNet-18 features -> 250-d image embedding -> one
      independent 2-way linear head per attribute.
    - ``forward(images, attributes)``: concatenates the 250-d image embedding with a
      250-d embedding of the attribute vector and maps the 500-d result to class logits.

    Args:
        num_attributes: attribute-vector length and number of pretext heads (default 312).
        num_classes: number of class logits produced by ``forward`` (default 200).
    """

    def __init__(self, num_attributes=312, num_classes=200):
        super().__init__()
        embed_dim = 250  # width of both modality embeddings before fusion

        # MODALITY 1: images. No weights are requested, so ResNet-18 starts untrained.
        self.image_model = models.resnet18()
        input_size = self.image_model.fc.out_features  # 1000 for the default resnet18 head

        self.classification = nn.Sequential(
            nn.Linear(input_size, 300),
            nn.BatchNorm1d(300),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(300, embed_dim),
        )

        # Pretext heads: an independent 2-way classifier per attribute.
        self.attribute_prediction = nn.ModuleList(
            nn.Linear(embed_dim, 2) for _ in range(num_attributes)
        )

        # MODALITY 2: the attribute vector.
        self.attribute_model = nn.Sequential(
            nn.Linear(num_attributes, 300),
            nn.BatchNorm1d(300),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(300, embed_dim),
        )

        # Fusion head over the concatenated [image, attribute] embeddings.
        self.fused = nn.Sequential(
            nn.Linear(2 * embed_dim, 250),
            nn.BatchNorm1d(250),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(250, num_classes),
        )

    def pretrain(self, images):
        """Return a list with one (batch, 2) logit tensor per attribute head."""
        embedding = self.classification(self.image_model(images))
        # Iterate the ModuleList itself so the output count always matches the heads.
        return [head(embedding) for head in self.attribute_prediction]

    def forward(self, images, attributes):
        """Return class logits from the fused image and attribute embeddings."""
        image_emb = self.classification(self.image_model(images))
        attr_emb = self.attribute_model(attributes)
        return self.fused(torch.cat([image_emb, attr_emb], dim=1))


In [10]:
# Seed torch's global RNG so weight initialisation is reproducible across runs
# (DataLoader shuffling and torch-based random augmentations drawn later from the
# same global RNG are affected too).
SEED = 0
torch.manual_seed(SEED)
SSL_model = SSLModel()

In [11]:
def default_device():
    """Return the first CUDA device when CUDA is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")

def to_device(data, device):
    """Recursively move every tensor inside ``data`` to ``device``.

    Lists and tuples come back as lists; dicts come back as dicts with the same keys.
    Leaves without a ``.to`` method (e.g. str or int metadata that a collate function
    may emit) are returned unchanged instead of raising AttributeError.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    if not hasattr(data, "to"):
        return data
    return data.to(device)

class DeviceDataLoader:
    """Iterable wrapper that hands out each batch of ``dataload`` already on ``device``."""

    def __init__(self, dataload, *, device):
        self.dataload, self.device = dataload, device

    def __iter__(self):
        # Batches are transferred lazily, one at a time, as the caller iterates.
        yield from (to_device(batch, self.device) for batch in self.dataload)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dataload)

device = default_device()
print(f"Using Device: {device}")

SSL_model.to(device)

# Wrap each split's DataLoader so every batch arrives already on `device`.
# Only the training loader shuffles and uses worker processes.
train_dataloader, test_dataloader, val_dataloader = (
    DeviceDataLoader(DataLoader(split, batch_size=64, **loader_kwargs), device=device)
    for split, loader_kwargs in (
        (train_dataset, dict(shuffle=True, num_workers=2)),
        (test_dataset, dict(shuffle=False)),
        (val_dataset, dict(shuffle=False)),
    )
)

Using Device: cuda:0


## SSL Pretext Task Training

In [12]:
!pip install pkbar
from pkbar import Kbar

Collecting pkbar
  Downloading pkbar-0.5-py3-none-any.whl (9.2 kB)
Installing collected packages: pkbar
Successfully installed pkbar-0.5


In [13]:
# Cross-entropy loss (mean over the batch) for the classification-style objectives.
criterion = nn.CrossEntropyLoss()

# A single SGD-with-momentum optimizer over every parameter of the model: image and
# attribute branches, pretext heads and the fusion head alike.
optimizer = optim.SGD(
    SSL_model.parameters(),
    lr=0.02,
    momentum=0.9,
)

In [14]:
# TRAINING FOR PRETRAIN: predict each image's attribute vector from the image alone.
# NOTE: the loss is now the mean cross-entropy over batch *and* attribute heads for
# both train and val. The previous version divided the summed head losses by a
# hard-coded 64 (the batch size, which CrossEntropyLoss already averages over) for
# training only, so loss and val_loss were on different scales. Gradients are now
# 64/312 of the old magnitude — retune the learning rate if needed.

EPOCHS = 5


def attribute_loss_and_accuracy(head_logits, attributes):
    """Pretext loss and accuracy averaged over all attribute heads for one batch.

    Args:
        head_logits: list from ``SSL_model.pretrain``, one (batch, 2) tensor per attribute.
        attributes: (batch, n_attributes) tensor of 0/1 targets (cast to long here).

    Returns:
        (loss, accuracy): ``loss`` is a scalar tensor (keeps grad); ``accuracy`` is a float.
    """
    logits = torch.stack(head_logits, dim=2)  # (batch, 2, n_attributes)
    targets = attributes.long()               # (batch, n_attributes)
    # CrossEntropyLoss accepts (N, C, d1) input with (N, d1) targets and averages over
    # all N*d1 positions, i.e. the mean of the per-head batch-mean losses.
    loss = criterion(logits, targets)
    accuracy = (logits.argmax(dim=1) == targets).float().mean().item()
    return loss, accuracy


for epoch in range(EPOCHS):
    kbar = Kbar(target=len(train_dataloader) + len(val_dataloader), epoch=epoch, num_epochs=EPOCHS)

    # training
    SSL_model.train()
    for i, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss, acc = attribute_loss_and_accuracy(
            SSL_model.pretrain(batch["image"]), batch["attributes"]
        )
        loss.backward()
        optimizer.step()
        kbar.update(i, values=[("loss", loss.item()), ("acc", acc)])

    # validation (same loss normalisation as training, so the two are comparable)
    SSL_model.eval()
    with torch.no_grad():
        for i, batch in enumerate(val_dataloader, start=len(train_dataloader)):
            val_loss, val_acc = attribute_loss_and_accuracy(
                SSL_model.pretrain(batch["image"]), batch["attributes"]
            )
            kbar.update(i, values=[("val_loss", val_loss.item()), ("val_acc", val_acc)])

    kbar.add(1)

Epoch: 1/5


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 2/5
Epoch: 3/5
Epoch: 4/5
Epoch: 5/5


## Downstream Task Training

In [15]:
# Downstream stage: train the full model to predict the class label from the image
# plus the attribute vector.
# NOTE(review): `optimizer` is the same SGD instance used in pretraining, so its
# momentum buffers carry over into this stage — confirm that is intended.

EPOCHS = 5

for epoch in range(EPOCHS):
    kbar = Kbar(target=len(train_dataloader) + len(val_dataloader), epoch=epoch, num_epochs=EPOCHS)

    # training
    SSL_model.train()
    for i, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        logits = SSL_model(batch["image"], batch["attributes"])
        loss = criterion(logits, batch["label"])
        loss.backward()
        optimizer.step()

        # argmax of the logits equals argmax of their softmax, so the softmax is skipped.
        acc = (logits.argmax(dim=1) == batch["label"]).float().mean().item()
        kbar.update(i, values=[("loss", loss.item()), ("acc", acc)])

    # validation
    SSL_model.eval()
    with torch.no_grad():
        for i, batch in enumerate(val_dataloader, start=len(train_dataloader)):
            logits = SSL_model(batch["image"], batch["attributes"])
            val_loss = criterion(logits, batch["label"])
            val_acc = (logits.argmax(dim=1) == batch["label"]).float().mean().item()
            kbar.update(i, values=[("val_loss", val_loss.item()), ("val_acc", val_acc)])

    kbar.add(1)

Epoch: 1/10
Epoch: 2/10
Epoch: 3/10
Epoch: 4/10
Epoch: 5/10
Epoch: 6/10
Epoch: 7/10
Epoch: 8/10
Epoch: 9/10
Epoch: 10/10
