# In[ ]:
from pathlib import Path
import json
import random
from typing import Iterable, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import GroupShuffleSplit

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from tqdm.auto import tqdm

PROJECT_DIR = Path('.').resolve()
DATA_DIR = PROJECT_DIR / 'pig_posture_recognition'
TRAIN_IMAGES = DATA_DIR / 'train_images'
TEST_IMAGES = DATA_DIR / 'test_images'

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 42
# Seed the Python, NumPy and torch RNGs (plus every CUDA device when one is present).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Load metadata
train_df = pd.read_csv(DATA_DIR / 'train.csv')
test_df = pd.read_csv(DATA_DIR / 'test.csv')
# One class name per non-blank line; a name's position in the file becomes its integer id
# (presumably matching the class_id column of train.csv -- confirm against the dataset).
# Decode explicitly as UTF-8 so the result does not depend on the platform locale.
_class_lines = (DATA_DIR / 'pig_posture_classes.txt').read_text(encoding='utf-8').splitlines()
class_names = [c.strip() for c in _class_lines if c.strip()]
id_to_name = dict(enumerate(class_names))
name_to_id = {name: idx for idx, name in id_to_name.items()}

IMGSZ = 299  # square crop side fed to the network (Inception v3's native 299x299 input)
BATCH = 16
EPOCHS = 3
LR = 2e-4
NUM_WORKERS = 0  # load in the main process to avoid multiprocessing shutdown issues on Kaggle
PIN_MEMORY = False


def _clip_span(center: float, size: float, limit: int) -> Tuple[int, int]:
    """Return integer ``(lo, hi)`` for a centred span, clipped to ``[0, limit]`` and >= 1 px wide."""
    lo = int(max(0, center - size / 2))
    hi = int(min(limit, center + size / 2))
    # Keep the start strictly inside the image and the end strictly after the start, so a
    # sub-pixel or off-image box still yields a non-empty crop instead of a 0-size image.
    lo = min(lo, limit - 1)
    hi = max(hi, lo + 1)
    return int(lo), int(hi)


def bbox_to_xyxy(bbox: Iterable[float], img_w: int, img_h: int) -> Tuple[int, int, int, int]:
    """Convert a centre-format box into integer pixel corners clipped to the image.

    Args:
        bbox: four numbers unpacked as ``(x_c, y_c, w, h)``, i.e. box centre plus width and
            height, in pixels (NOTE(review): format inferred from the unpacking -- confirm
            against the dataset's annotation spec).
        img_w: image width in pixels.
        img_h: image height in pixels.

    Returns:
        ``(x1, y1, x2, y2)`` usable as a ``PIL.Image.crop`` box. The corners are truncated
        to ints, lie within ``[0, img_w] x [0, img_h]``, and span at least one pixel in
        each direction, so the crop is never empty.
    """
    x_c, y_c, w, h = bbox
    x1, x2 = _clip_span(x_c, w, img_w)
    y1, y2 = _clip_span(y_c, h, img_h)
    return x1, y1, x2, y2


def _tensor_and_normalize():
    """Build the shared pipeline tail: PIL -> float tensor, then ImageNet mean/std normalisation."""
    return [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Both pipelines first resize the shorter side to ~15% above IMGSZ, then crop back to IMGSZ.
_RESIZE_SHORT_SIDE = int(IMGSZ * 1.15)

# Training: random scale/aspect crop, horizontal flip and mild colour jitter as augmentation.
train_tfms = T.Compose(
    [
        T.Resize(_RESIZE_SHORT_SIDE),
        T.RandomResizedCrop(IMGSZ, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.05),
    ]
    + _tensor_and_normalize()
)

# Validation / test: deterministic centre crop, no augmentation.
val_tfms = T.Compose(
    [T.Resize(_RESIZE_SHORT_SIDE), T.CenterCrop(IMGSZ)] + _tensor_and_normalize()
)


class PigPostureDataset(Dataset):
    """One sample per CSV row: the image region given by that row's bounding box.

    Each row is expected to provide ``image_id`` (file name under ``img_root``), ``bbox``
    (a JSON-encoded list consumed by ``bbox_to_xyxy``), ``width`` and ``height``, plus
    ``class_id`` when ``has_label`` is True or ``row_id`` when it is False.
    """

    def __init__(self, df: pd.DataFrame, transform, img_root: Path, has_label: bool = True):
        # Reset the index so positional ``iloc`` lookups match 0..len-1 after any split.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.img_root = img_root
        self.has_label = has_label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, class_id)`` for labelled data, otherwise ``(image, row_id)``."""
        row = self.df.iloc[idx]
        box = bbox_to_xyxy(json.loads(row['bbox']), row['width'], row['height'])
        # Manage the opened file itself (not a converted copy) so its handle is released
        # deterministically. Cropping before the per-pixel RGB conversion means only the
        # box region is converted rather than the whole image.
        with Image.open(self.img_root / row['image_id']) as im:
            crop = im.crop(box).convert('RGB')
        image = self.transform(crop)
        if self.has_label:
            return image, int(row['class_id'])
        return image, row['row_id']


# Hold out 20% of the *images* for validation. Grouping by image_id keeps all boxes from
# one image on the same side of the split, so validation crops never share a photo with
# training crops.
splitter = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=SEED)
train_idx, val_idx = next(splitter.split(train_df, groups=train_df['image_id']))
train_split = train_df.iloc[train_idx].copy()
val_split = train_df.iloc[val_idx].copy()

train_ds = PigPostureDataset(train_split, train_tfms, TRAIN_IMAGES, has_label=True)
val_ds = PigPostureDataset(val_split, val_tfms, TRAIN_IMAGES, has_label=True)

# drop_last=True: if len(train_ds) % BATCH == 1, the final training batch would hold a
# single sample. Inception's auxiliary head reduces features to 1x1, so its BatchNorm would
# then see one value per channel and raise in train mode.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Model: ImageNet-pretrained Inception v3 with the main classifier replaced for our classes.
# The auxiliary head (AuxLogits) keeps its original ImageNet classifier; run_epoch discards
# the auxiliary output, so it is never scored against these labels.
base_model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True)
base_model.fc = nn.Linear(base_model.fc.in_features, len(class_names))
base_model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(base_model.parameters(), lr=LR, weight_decay=1e-4)
# Cosine decay over EPOCHS steps; the training loop calls scheduler.step() once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


def run_epoch(model, loader, train_mode: bool, *, loss_fn=None, optim=None, device=None):
    """Run one pass over ``loader``, either training or evaluating ``model``.

    Args:
        model: network to run; switched to train or eval mode according to ``train_mode``.
        loader: iterable of ``(images, targets)`` batches.
        train_mode: if True, gradients are enabled and the optimizer steps after every
            batch; if False, the pass runs without gradients.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
        optim: optimizer used when training; defaults to the module-level ``optimizer``.
        device: device batches are moved to; defaults to the module-level ``DEVICE``.

    Returns:
        ``(mean per-sample loss, accuracy)`` over every sample seen in this pass.

    Raises:
        ValueError: if ``loader`` yields no samples (e.g. a ``drop_last=True`` loader over
            fewer samples than one batch), instead of a bare ZeroDivisionError.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    optim = optimizer if optim is None else optim
    device = DEVICE if device is None else device

    model.train(mode=train_mode)
    running_loss, correct, total = 0.0, 0, 0
    iterator = tqdm(loader, leave=False, desc='train' if train_mode else 'val')
    for images, targets in iterator:
        images, targets = images.to(device), targets.to(device)
        with torch.set_grad_enabled(train_mode):
            outputs = model(images)
            # torchvision's Inception v3 with aux_logits returns (logits, aux_logits) in
            # train mode; only the main logits are scored.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = loss_fn(outputs, targets)
            if train_mode:
                optim.zero_grad()
                loss.backward()
                optim.step()
        batch_size = targets.size(0)
        # loss is a batch mean; re-weight by batch size so the epoch mean is per-sample.
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(1) == targets).sum().item()
        total += batch_size
        iterator.set_postfix(loss=running_loss / total, acc=correct / total)
    if total == 0:
        raise ValueError('run_epoch: loader yielded no samples')
    return running_loss / total, correct / total


# Train for EPOCHS epochs, keeping the weights with the best validation accuracy on disk.
# best_acc starts at -inf rather than 0.0 so the first epoch always writes a checkpoint.
# The inference step loads best_path unconditionally, and with 0.0 an all-zero validation
# accuracy would have left no file to load.
best_acc = float('-inf')
best_path = PROJECT_DIR / 'best_inception.pt'
for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = run_epoch(base_model, train_loader, train_mode=True)
    val_loss, val_acc = run_epoch(base_model, val_loader, train_mode=False)
    scheduler.step()  # one scheduler step per epoch
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({'model_state': base_model.state_dict(), 'acc': val_acc}, best_path)
    print(f'Epoch {epoch}/{EPOCHS} | train loss {train_loss:.4f} acc {train_acc:.3f} | val loss {val_loss:.4f} acc {val_acc:.3f}')

# Inference and submission: reload the best checkpoint saved during training, predict the
# argmax class for every test row, and write submission.csv with columns row_id, class_id.
checkpoint = torch.load(best_path, map_location=DEVICE)
base_model.load_state_dict(checkpoint['model_state'])
base_model.eval()

test_ds = PigPostureDataset(test_df, val_tfms, TEST_IMAGES, has_label=False)
test_loader = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

preds, row_ids = [], []
with torch.no_grad():
    for images, ids in tqdm(test_loader, leave=False, desc='infer'):
        images = images.to(DEVICE)
        outputs = base_model(images)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        preds.extend(outputs.argmax(1).cpu().tolist())
        # The default collate function stacks numeric row ids into a tensor but leaves
        # strings as a list. Convert tensors to plain Python values so the CSV gets "7",
        # not "tensor(7)".
        row_ids.extend(ids.tolist() if torch.is_tensor(ids) else list(ids))

submission = pd.DataFrame({'row_id': row_ids, 'class_id': preds})
submission.to_csv('submission.csv', index=False)
submission.head()
