In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import Subset, Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import os
from PIL import Image
import pandas

In [5]:
class BirdsDataset(Dataset):
    """Image-classification dataset backed by a ``{filename: class_id}`` mapping.

    The mapping is split 70/30 into train/validation parts with
    ``train_test_split(..., shuffle=True, random_state=0)``, so two instances
    built from the same ``gt`` (one with ``train=True``, one with
    ``train=False``) see disjoint, reproducible subsets.

    Args:
        gt: Mapping whose ``.items()`` yields ``(image_filename, class_id)`` pairs.
        img_dir: Directory that the image filenames are joined onto.
        train: Keyword-only. ``True`` selects the 70% training part,
            ``False`` the 30% validation part.
        transform: Keyword-only, optional. Callable applied to the RGB PIL
            image in ``__getitem__``; ``None`` returns the PIL image as is.
    """

    def __init__(self, gt, img_dir, *, train=True, transform=None):
        self.transform = transform

        # Fixed random_state keeps the train/val partition identical across
        # the separately constructed train and validation instances.
        train_gt, val_gt = train_test_split(
            list(gt.items()), test_size=0.3, shuffle=True, random_state=0
        )
        selected = train_gt if train else val_gt

        self.images = [os.path.join(img_dir, img_filename) for img_filename, _ in selected]
        self.labels = [class_id for _, class_id in selected]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, class_id)``.

        The image is opened from disk, converted to RGB and, when a transform
        was supplied, passed through it. ``class_id`` is returned exactly as
        it was stored in ``gt``.
        """
        # The context manager closes the underlying file right away;
        # ``convert`` returns a fully loaded image that stays valid afterwards.
        with Image.open(self.images[index]) as raw_image:
            image = raw_image.convert('RGB')

        # Explicit None check: a transform object that happens to be falsy
        # (e.g. an empty container) must not be skipped silently.
        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[index]

In [6]:
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
class MobileNet(nn.Module):
    """ImageNet-pretrained MobileNetV2 with a replaced classification head.

    The stock classifier is swapped for
    ``Dropout -> Linear(last_channel, 512) -> BatchNorm1d -> ReLU ->
    Linear(512, num_classes) -> BatchNorm1d``, and all but the last four
    blocks of the convolutional feature extractor are frozen.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)

        # ``last_channel`` is an attribute of the wrapped torchvision model,
        # not of this wrapper; reading it from ``self`` raised AttributeError.
        in_features = self.model.last_channel
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
            nn.BatchNorm1d(num_classes)
        )

        # The torchvision model has only two direct children (``features`` and
        # ``classifier``), so slicing ``self.model.children()[:-4]`` selected
        # nothing and no parameter was ever frozen. Slice the blocks inside
        # ``features`` instead so only the last four blocks stay trainable.
        for block in list(self.model.features.children())[:-4]:
            for param in block.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the class logits produced by the wrapped network for ``x``."""
        return self.model(x)

In [None]:
from tqdm import tqdm

In [None]:
@torch.no_grad()
def test(model, criterion, val_loader, device, tqdm_desc):
    """Evaluate ``model`` on every batch of ``val_loader`` without gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        criterion: Callable ``criterion(logits, target)`` returning a loss
            tensor that supports ``.item()``.
        val_loader: Iterable of ``(data, target)`` batches exposing a
            ``.dataset`` attribute with a length.
        device: Device the batches are moved to before the forward pass.
        tqdm_desc: Label shown on the progress bar.

    Returns:
        ``(accuracy, loss)``: correct predictions and batch-size-weighted loss,
        each divided by ``len(val_loader.dataset)``.
    """
    model.eval()
    n_correct = 0.0
    loss_sum = 0.0

    for batch, labels in tqdm(val_loader, desc=tqdm_desc):
        batch, labels = batch.to(device), labels.to(device)

        outputs = model(batch)
        batch_loss = criterion(outputs, labels)

        predictions = outputs.argmax(dim=1)
        n_correct += (predictions == labels).sum().item()
        # Weight by batch size (presumably the criterion returns a per-batch
        # mean) so the final division yields a per-sample average.
        loss_sum += batch_loss.item() * labels.shape[0]

    n_samples = len(val_loader.dataset)
    return n_correct / n_samples, loss_sum / n_samples

def train_epoch(model, optimizer, criterion, train_loader, device, tqdm_desc):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; it is switched to train mode.
        optimizer: Optimizer whose ``zero_grad``/``step`` wrap each batch.
        criterion: Callable ``criterion(logits, target)`` returning a loss
            tensor that supports ``.backward()`` and ``.item()``.
        train_loader: Iterable of ``(data, target)`` batches exposing a
            ``.dataset`` attribute with a length.
        device: Device the batches are moved to before the forward pass.
        tqdm_desc: Label shown on the progress bar.

    Returns:
        ``(accuracy, loss)``: correct predictions and batch-size-weighted loss,
        each divided by ``len(train_loader.dataset)``. Both are measured on the
        logits computed before that batch's optimizer step.
    """
    model.train()
    n_correct = 0.0
    loss_sum = 0.0

    for batch, labels in tqdm(train_loader, desc=tqdm_desc):
        batch, labels = batch.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        predictions = outputs.argmax(dim=1)
        n_correct += (predictions == labels).sum().item()
        # Weight by batch size (presumably the criterion returns a per-batch
        # mean) so the final division yields a per-sample average.
        loss_sum += batch_loss.item() * labels.shape[0]

    n_samples = len(train_loader.dataset)
    return n_correct / n_samples, loss_sum / n_samples

def train(model, optimizer, n_epochs, train_loader, val_loader, scheduler=None,
          criterion=None, device=None):
    """Train ``model`` for ``n_epochs`` epochs, validating after each one.

    Args:
        model: Network to optimise.
        optimizer: Optimizer bound to the model parameters.
        n_epochs: Number of train/validate cycles to run.
        train_loader: Loader handed to ``train_epoch``.
        val_loader: Loader handed to ``test``.
        scheduler: Optional LR scheduler, stepped once per epoch.
            ``ReduceLROnPlateau`` receives the validation loss; any other
            scheduler is stepped without arguments.
        criterion: Optional loss callable; defaults to ``nn.CrossEntropyLoss()``.
        device: Optional device for the batches; defaults to the device of the
            model's first parameter (CPU when the model has no parameters).

    Returns:
        ``(train_loss_log, train_acc_log, val_loss_log, val_acc_log)``: four
        lists holding one value per epoch.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss_log, train_acc_log, val_loss_log, val_acc_log = [], [], [], []

    for epoch in range(n_epochs):
        # train_epoch and test return (accuracy, loss) in that order and need
        # the criterion, device and progress-bar label passed explicitly.
        train_acc, train_loss = train_epoch(
            model, optimizer, criterion, train_loader, device,
            f"Training {epoch + 1}/{n_epochs}"
        )
        val_acc, val_loss = test(
            model, criterion, val_loader, device,
            f"Validating {epoch + 1}/{n_epochs}"
        )

        # Per-epoch metrics are scalars, so append them; extend() on a float
        # raises TypeError.
        train_loss_log.append(train_loss)
        train_acc_log.append(train_acc)

        val_loss_log.append(val_loss)
        val_acc_log.append(val_acc)

        print(f"Epoch {epoch}")
        print(f" train loss: {train_loss}, train acc: {train_acc}")
        print(f" val loss: {val_loss}, val acc: {val_acc}\n")

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

    return train_loss_log, train_acc_log, val_loss_log, val_acc_log

In [11]:
from run import read_csv

batch_size = 64
num_epochs = 15

# NOTE(review): absolute, machine-specific paths; consider a configurable data root.
train_gt = read_csv("/Users/danny.paleyev/birds_classification/public_tests/00_test_img_input/train/gt.csv")
train_img_dir = "/Users/danny.paleyev/birds_classification/public_tests/00_test_img_input/train/images"

# The DataLoader's default collate function cannot batch raw PIL images, so
# every sample is resized to a common shape and converted to a normalised
# tensor (ImageNet statistics, matching the pretrained MobileNetV2 weights).
image_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = BirdsDataset(train_gt, train_img_dir, train=True, transform=image_transform)
val_dataset = BirdsDataset(train_gt, train_img_dir, train=False, transform=image_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The classifier width must be the number of classes, not the number of
# training samples.
# NOTE(review): assumes class ids are contiguous integers 0..K-1 — confirm
# against what read_csv returns.
num_classes = len(set(train_gt.values()))
model = MobileNet(num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
# NOTE(review): not passed to train() below — the call sticks to train's
# positional parameters (model, optimizer, n_epochs, train_loader, val_loader, scheduler).
criterion = torch.nn.CrossEntropyLoss()
scheduler = None #torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

train_loss_log, train_acc_log, val_loss_log, val_acc_log = train(
    model, optimizer, num_epochs, train_loader, val_loader, scheduler
)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /Users/danny.paleyev/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1108)>