In [None]:
import os
import torch
import pickle
import numpy as np
import pandas as pd

from PIL import Image
import torch.nn as nn
from tqdm import tqdm
import lightgbm as lgb

from torchvision import transforms
from sklearn.model_selection import KFold
from torchvision import models, transforms
from sklearn.preprocessing import LabelEncoder
from torchvision.models import efficientnet_b0
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, Dataset
from sklearn.metrics import mean_squared_error, mean_absolute_error

In [None]:
class CONFIG:
    """Notebook-wide settings: random seed, dataset paths, device and metric weights."""

    # Seed for the global NumPy / PyTorch RNGs (applied right after `cfg` is created).
    SEED = 67

    TRAIN_PATH = '/kaggle/input/csiro-biomass/train.csv'
    TEST_PATH =  '/kaggle/input/csiro-biomass/test.csv'

    # Resolved once, when the class body executes.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-target weights. Insertion order matters: the training cell turns
    # `weights.values()` into a tensor that is multiplied column-wise with the
    # dataset targets, so it must match the dataset's target column order.
    weights = {
        "Dry_Clover_g": 0.1,
        "Dry_Dead_g": 0.1,
        "Dry_Green_g": 0.1,
        "Dry_Total_g": 0.5,
        "GDM_g": 0.2
    }

cfg = CONFIG()

# Apply the seed so that everything drawing from the global RNGs (weight
# initialisation of the new head, DataLoader shuffling, NumPy-based sampling)
# is repeatable under Restart & Run All.
np.random.seed(cfg.SEED)
torch.manual_seed(cfg.SEED)

In [None]:
class BiomassDataset(Dataset):
    """Image dataset backed by a DataFrame with an 'image_path' column.

    Args:
        df: DataFrame with one row per image. Must contain 'image_path' and,
            when ``train`` is true, every column in ``TARGET_COLUMNS``.
        train: When true, items are ``(image, targets)`` pairs; otherwise
            only the image tensor is returned.
        root: Directory that the relative 'image_path' values are joined to.

    Each image is resized to 224x224, converted to a tensor and normalised
    with the mean/std values listed in the transform below.
    """

    # Target columns, in the order they appear in the returned target tensor.
    TARGET_COLUMNS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Dry_Total_g', 'GDM_g']

    def __init__(self, df, train=True, root='/kaggle/input/csiro-biomass'):
        self.train = train
        self.df = df
        self.root = root
        self.tf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

        # Convert the targets once instead of slicing a fresh 5-column
        # sub-DataFrame on every __getitem__ call.
        if self.train:
            self.targets = torch.tensor(
                df[self.TARGET_COLUMNS].to_numpy(), dtype=torch.float
            )
        else:
            self.targets = None

    def __len__(self):
        """Number of rows (images) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and transform the image at position ``idx`` (positional, not label-based)."""
        path = self.df['image_path'].iloc[idx]

        # Image.open keeps the file handle open until the image object is
        # released; the context manager closes it as soon as the pixels have
        # been decoded and transformed.
        with Image.open(f'{self.root}/{path}') as raw:
            img = self.tf(raw.convert("RGB"))

        if self.train:
            return img, self.targets[idx]
        return img

In [None]:
def weighted_r2_torch(y_true, y_pred, w):
    """Weighted coefficient of determination over all (sample, target) entries.

    Args:
        y_true: Tensor of ground-truth values.
        y_pred: Tensor of predictions, same shape as ``y_true``.
        w: Weights broadcastable against ``y_true`` (for example one weight
            per target column). They do not need to sum to 1.

    Returns:
        A 0-dim tensor ``1 - SS_res / SS_tot``. The baseline in ``SS_tot`` is
        the single weighted mean of *all* entries, so a constant prediction
        equal to that mean scores 0 and a perfect prediction scores 1. When
        every entry of ``y_true`` equals the mean, ``SS_tot`` is 0 and the
        result is nan/inf.
    """
    # Materialise the weights at the full shape so the normaliser counts every
    # (sample, target) entry, regardless of how `w` broadcasts.
    weights = w * torch.ones_like(y_true)

    # Global weighted mean: one scalar baseline, not one value per row.
    y_bar = (weights * y_true).sum() / weights.sum()

    ss_res = (weights * (y_true - y_pred) ** 2).sum()
    ss_tot = (weights * (y_true - y_bar) ** 2).sum()
    return 1 - ss_res / ss_tot

In [None]:
def clean_ids(data):
    """Return the part of the string ``data`` that precedes the first '__'."""
    return data.split('__')[0]


def preprocessing(data):
    """Reshape a raw competition table into one row per (sample_id, image_path).

    The input frame is left untouched; all work happens on a copy.

    Args:
        data: DataFrame with at least 'sample_id' and 'image_path'. When it
            also has a 'target' column, 'target_name' is required as well.

    Returns:
        * With a 'target' column: a wide frame with 'sample_id', 'image_path'
          and one column per distinct 'target_name' value. Duplicate
          (sample_id, image_path, target_name) rows are combined using
          ``pivot_table``'s default aggregation (mean).
        * Without it: the de-duplicated ('sample_id', 'image_path') pairs.
    """
    # `assign` returns a new frame, so the caller's DataFrame keeps its
    # original ids and re-running the calling cell stays safe.
    cleaned = data.assign(sample_id=data['sample_id'].map(clean_ids))

    if 'target' in cleaned.columns:
        return cleaned.pivot_table(
            index=['sample_id', 'image_path'],
            columns='target_name',
            values='target',
        ).reset_index()

    return cleaned[['sample_id', 'image_path']].drop_duplicates()

In [None]:
# Read the raw training table and reshape it to one row per image.
train_wide = preprocessing(pd.read_csv(cfg.TRAIN_PATH))

# 70/30 hold-out split. Seeding it keeps the same rows in the validation set
# on every run, so validation scores are comparable between runs.
train, val = train_test_split(train_wide, test_size=0.3, random_state=cfg.SEED)

In [None]:
# Read the raw test table and run it through the same preprocessing as the
# training data (id clean-up, one row per image).
test = preprocessing(pd.read_csv(cfg.TEST_PATH))

In [None]:
# Training batches are reshuffled every epoch.
train_dataset = BiomassDataset(train)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Validation predictions are concatenated and scored as a whole, so their
# order does not matter. Not shuffling keeps evaluation deterministic and
# avoids drawing from the global RNG on every validation pass.
val_dataset = BiomassDataset(val)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Unlabelled images: items are image tensors only (train=False).
test_dataset = BiomassDataset(test, train=False)

In [None]:
# EfficientNet-B0 backbone initialised from the "IMAGENET1K_V1" weights.
model = efficientnet_b0(weights="IMAGENET1K_V1")

# Swap the second classifier layer for a 5-output regression head.
# The trailing ReLU clamps every prediction to be >= 0.
model.classifier[1] = nn.Sequential(
    nn.Linear(1280, 5),
    nn.ReLU(),
)

model = model.to(cfg.device)

In [None]:
class EarlyStopping:
    """Stop training when a monitored validation value stops improving.

    Call the instance once per epoch with the monitored value and the model.
    On every improvement the model's ``state_dict`` is written to ``path``;
    after ``patience`` consecutive calls without improvement ``early_stop``
    becomes True.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum amount by which the value must beat the best one
            seen so far to count as an improvement. With the default 0 a tie
            counts as an improvement.
        verbose: Print counter and checkpoint messages.
        mode: 'max' when higher values are better (the default, matching the
            training loop, which passes weighted R2); 'min' for loss-like
            values where lower is better.
        path: File the best ``state_dict`` is saved to.

    Attributes:
        best_score: Best value seen so far, sign-adjusted so that higher is
            always better (the value itself in 'max' mode, its negation in
            'min' mode). None until the first call.
        val_loss_min: Monitored value at the most recent checkpoint. The name
            is historical; in 'max' mode it holds the best (largest) value.
        counter: Consecutive calls without improvement.
        early_stop: True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=7, min_delta=0, verbose=False, mode='max',
                 path='efficientnet_b7_checkpoint.pt'):
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.mode = mode
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Start from the worst possible value for the chosen direction so the
        # first verbose message reads sensibly.
        self.val_loss_min = float('-inf') if mode == 'max' else float('inf')

    def __call__(self, val_loss, model):
        """Record one epoch's monitored value and update the stopping state.

        Args:
            val_loss: Monitored value for this epoch; a Python number or a
                single-element tensor. (The parameter name is kept for
                backward compatibility.)
            model: Model whose ``state_dict`` is saved on improvement.
        """
        # Store plain floats so no device tensor is kept alive between epochs.
        value = float(val_loss)
        score = value if self.mode == 'max' else -value

        improved = (
            self.best_score is None
            or score >= self.best_score + self.min_delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(value, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember the value."""
        if self.verbose:
            print(f'Validation score improved ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Early-stopping tracker used by the training loop: training halts after 100
# consecutive epochs without an improvement of the monitored value.
# verbose=True prints the patience counter and a message on each checkpoint.
early_stopping = EarlyStopping(patience=100, verbose=True)

In [None]:
# Element-wise squared error; the per-target weighting is applied manually below.
criterion = nn.MSELoss(reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Per-target weights in the order of cfg.weights, normalised to sum to 1.
w = torch.tensor(list(cfg.weights.values()), dtype=torch.float32)
w = w / w.sum()

# Loop-invariant: move the weights to the training device once instead of on
# every batch and every epoch.
w_device = w.to(cfg.device)
loss_weights = w_device.unsqueeze(0)  # shape (1, n_targets), broadcasts over the batch

for epoch in range(900):
    # ---- training pass ----
    model.train()
    total_loss = 0  # running sum of batch losses (not an average)

    for x, y in train_loader:
        x = x.to(cfg.device)
        y = y.to(cfg.device)

        optimizer.zero_grad()
        preds = model(x)

        loss = (criterion(preds, y) * loss_weights).mean()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    y_true_all = []
    y_pred_all = []

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(cfg.device)
            y = y.to(cfg.device)
            preds = model(x)

            y_true_all.append(y)
            y_pred_all.append(preds)

    y_true_all = torch.cat(y_true_all)
    y_pred_all = torch.cat(y_pred_all)

    # Score the whole validation set at once; .item() yields a plain float so
    # the early-stopping tracker does not hold on to a device tensor.
    r2 = weighted_r2_torch(y_true_all, y_pred_all, w_device).item()

    # Log before the stopping check so the final epoch's metrics are shown too.
    print(f"Epoch {epoch+1} | Loss {total_loss:.4f} | Weighted R2 {r2:.4f}")

    early_stopping(r2, model)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break