In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import sampler
from PIL import Image

import torchvision.datasets as dset
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import methodtools
import numpy as np
import os
import random
import pandas as pd
import pathlib
import pickle
import timeit

In [24]:
from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling

In [80]:
class PlantSeedlingDataset(Dataset):
    def __init__(self, root_dir, train=True):
        self.train = train
        self.root_dir = root_dir
        self.labels = [dir for dir in os.listdir(root_dir) if dir != 'encoding']
        self.label_map = {label: i for i, label in enumerate(self.labels)}
        self.feature_extractor = \
            BeitFeatureExtractor.from_pretrained('microsoft/beit-large-patch16-224-pt22k')
        self.bert_model = \
            BeitForMaskedImageModeling.from_pretrained('microsoft/beit-large-patch16-224-pt22k')
        
        images_with_label = []
        for label in self.labels:
            label_id = self.label_map[label]
            for image in os.listdir(os.path.join(root_dir, label)):
                images_with_label.append((image, label_id))
        
        random.seed(42)
        random.shuffle(images_with_label)
        self.images = images_with_label
        
        self.encoding_dir = os.path.join(root_dir, 'encoding')
        pathlib.Path(self.encoding_dir).mkdir(parents=True, exist_ok=True)
        
        for name, label in self.images:
            if os.path.isfile(os.path.join(self.encoding_dir, f'{name}.pt')):
                continue
            print(f"Processing {self.labels[label]}/{name}...")
            image = Image.open(os.path.join(self.root_dir, self.labels[label], name))
            image = TF.to_tensor(image)
            if image.shape[0] > 3:
                image = image[:3, :, :]
            image = TF.resize(image, (224, 224))
            image = TF.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            outputs = self.bert_model(**inputs)
            encoding = torch.tensor(outputs.logits, dtype=torch.bfloat16)
            torch.save(encoding, os.path.join(self.encoding_dir, f'{name}.pt'))        

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        name, label_id = self.images[idx]
        encoding = torch.load(os.path.join(self.encoding_dir, f'{name}.pt'))
        return {'encoding': encoding, 'label': label_id}

In [81]:
ds = PlantSeedlingDataset("data/plant-seedling-raw/")

In [72]:
len(ds)

4750

In [73]:
class ChunkSampler(sampler.Sampler):
    """Samples elements sequentially from some offset. 
    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """
    def __init__(self, num_samples, start = 0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples

NUM_TRAIN = 4000
NUM_VAL = 750
loader_train = DataLoader(ds, batch_size=32, sampler=ChunkSampler(NUM_TRAIN, 0))
loader_val = DataLoader(ds, batch_size=32, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

In [52]:
dtype = torch.BFloat16Tensor
model = nn.Sequential(
    nn.Flatten(), # see above for explanation
    nn.Linear(np.prod(list(ds[0]['encoding'].shape)), len(ds.labels)), # affine layer
)

model.type(dtype)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=1605632, out_features=12, bias=True)
)

In [53]:
gpu_dtype = torch.cuda.BFloat16Tensor

In [54]:
print_every = 25

def train(model, loss_fn, optimizer, num_epochs = 1):
    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        model.train()
        for t, data in enumerate(loader_train):
            x_var = Variable(data['encoding']).to("cuda")
            y_var = Variable(data['label']).to("cuda")
            scores = model(x_var)
            loss = loss_fn(scores, y_var)
            if (t + 1) % print_every == 0:
                print('t = %d, loss = %.4f' % (t + 1, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [78]:
@torch.no_grad()
def check_accuracy(model, loader):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval() # Put the model in test mode (the opposite of model.train(), essentially)
    for data in loader:
        x_var = Variable(data['encoding']).to("cuda")
        y_var = Variable(data['label'])

        scores = model(x_var)
        _, preds = scores.data.cpu().max(1)
        num_correct += (preds == y_var).sum()
        num_samples += preds.size(0)
    acc = float(num_correct) / num_samples
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

In [56]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3) 
best_acc = 0.0

In [57]:
model = model.to("cuda")

In [79]:
for epoch in range(100):
    train(model, loss_fn, optimizer, num_epochs=1)
    acc = check_accuracy(model, loader_val)
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "bert+linear.pt")

Starting epoch 1 / 1
t = 25, loss = 44.0000
t = 50, loss = 336.0000
t = 75, loss = 188.0000
t = 100, loss = 360.0000
t = 125, loss = 412.0000
Checking accuracy on validation set
Got 338 / 750 correct (45.07)
Starting epoch 1 / 1
t = 25, loss = 568.0000
t = 50, loss = 480.0000
t = 75, loss = 155.0000
t = 100, loss = 412.0000
t = 125, loss = 328.0000
Checking accuracy on validation set
Got 418 / 750 correct (55.73)
Starting epoch 1 / 1
t = 25, loss = 101.0000
t = 50, loss = 260.0000
t = 75, loss = 62.0000
t = 100, loss = 298.0000
t = 125, loss = 106.0000
Checking accuracy on validation set
Got 435 / 750 correct (58.00)
Starting epoch 1 / 1
t = 25, loss = 39.0000
t = 50, loss = 96.0000
t = 75, loss = 128.0000
t = 100, loss = 53.0000
t = 125, loss = 72.0000
Checking accuracy on validation set
Got 461 / 750 correct (61.47)
Starting epoch 1 / 1
t = 25, loss = 54.0000
t = 50, loss = 110.0000
t = 75, loss = 39.0000
t = 100, loss = 8.0000
t = 125, loss = 56.0000
Checking accuracy on validation 

KeyboardInterrupt: 

In [82]:
class PlantSeedlingTestDataset(Dataset):
    def __init__(self, root_dir):
        self.train = False
        self.root_dir = root_dir
        self.feature_extractor = \
            BeitFeatureExtractor.from_pretrained('microsoft/beit-large-patch16-224-pt22k')
        self.bert_model = \
            BeitForMaskedImageModeling.from_pretrained('microsoft/beit-large-patch16-224-pt22k')
        
        images = [file for file in os.listdir(self.root_dir) if file != 'encoding']
        self.images = images
        
        self.encoding_dir = os.path.join(root_dir, 'encoding')
        pathlib.Path(self.encoding_dir).mkdir(parents=True, exist_ok=True)
        
        for name in self.images:
            if os.path.isfile(os.path.join(self.encoding_dir, f'{name}.pt')):
                continue
            print(f"Processing {name}...")
            image = Image.open(os.path.join(self.root_dir, name))
            image = TF.to_tensor(image)
            if image.shape[0] > 3:
                image = image[:3, :, :]
            image = TF.resize(image, (224, 224))
            image = TF.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            outputs = self.bert_model(**inputs)
            encoding = torch.tensor(outputs.logits, dtype=torch.bfloat16)
            torch.save(encoding, os.path.join(self.encoding_dir, f'{name}.pt'))        

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        encoding = torch.load(os.path.join(self.encoding_dir, f'{self.images[idx]}.pt'))
        return encoding

In [83]:
tds = PlantSeedlingTestDataset("data/plant-seedling-test/")

Processing 2053ada02.png...


  encoding = torch.tensor(outputs.logits, dtype=torch.bfloat16)


Processing 917aa970b.png...
Processing 46c4aed02.png...
Processing 817aacd06.png...
Processing 25c8aa30c.png...
Processing 3f13ad205.png...
Processing adb2a8107.png...
Processing b504a070f.png...
Processing 71c7ad202.png...
Processing 3217a1807.png...
Processing a530a8c00.png...
Processing e246a570b.png...
Processing 05e5a5c06.png...
Processing c9f6afa0c.png...
Processing 3b5ba1404.png...
Processing 1750ad208.png...
Processing c621aa904.png...
Processing eca8ad508.png...
Processing c2c1a8707.png...
Processing c395a1d0a.png...
Processing ee5ca2c09.png...
Processing 62d1a6303.png...
Processing c41ba990c.png...
Processing 2cb4abe06.png...
Processing cfb3a6509.png...
Processing caaeac306.png...
Processing 28a2a7408.png...
Processing 0975a0204.png...
Processing ab41ae606.png...
Processing 425eab609.png...
Processing 5423ad505.png...
Processing bee1a8206.png...
Processing c5f0adc04.png...
Processing fc58a690c.png...
Processing 5dfaa9101.png...
Processing 7863a4408.png...
Processing 3bf4a0c04

In [94]:
loader_test = DataLoader(tds, batch_size=64)

In [122]:
@torch.no_grad()
def save_result(model, loader):
    num_correct = 0
    num_samples = 0
    model.eval() # Put the model in test mode (the opposite of model.train(), essentially)
    result = []
    for x in loader:
        x_var = Variable(x.type(gpu_dtype))

        scores = model(x_var)
        _, preds = scores.data.cpu().max(1)
        result += preds

    classifications = [ds.labels[int(num)] for num in result]
    df = pd.DataFrame({"file": tds.images, "species": classifications})
    df = df.sort_values(["file"])
    df.to_csv("submission.csv", index=False)

In [123]:
save_result(model, loader_test)