# In [1]:
# %%
# Rice Variety Classification Notebook
# Combined code for variety classification using CBAM-ResNet18

# %%
# 1. Imports
import copy
import os
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from timm.data import Mixup
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import MulticlassF1Score
from torchvision import models, transforms
from torchvision.io import decode_image
from torchvision.transforms.functional import to_pil_image

# %%
# 2. Variety Information (optional)
# Descriptive facts for each known variety. Rows are listed in the order the
# varieties must appear in VARIETY_INFO (dict insertion order is preserved).
_VARIETY_FIELDS = (
    "origin",
    "characteristics",
    "growing_period",
    "optimal_conditions",
)
_VARIETY_ROWS = (
    ("Basmati", "India/Pakistan", "Long grain, aromatic",
     "120-150 days", "Warm climate, well-drained soil"),
    ("Jasmine", "Thailand", "Long grain, fragrant",
     "110-120 days", "Tropical climate, abundant water"),
    ("Arborio", "Italy", "Medium grain, high starch content",
     "130-150 days", "Temperate climate, consistent water"),
    ("Sushi", "Japan", "Short grain, sticky when cooked",
     "120-140 days", "Temperate climate, consistent water level"),
    ("Long Grain", "Various regions",
     "Long and slender grain, fluffy when cooked",
     "110-130 days", "Warm climate, good irrigation"),
)
# Maps variety name -> {field name -> description}.
VARIETY_INFO = {
    name: dict(zip(_VARIETY_FIELDS, facts))
    for name, *facts in _VARIETY_ROWS
}

# %%
# 3. Image Transform

def get_transform():
    """Return the preprocessing pipeline shared by training and inference.

    The pipeline resizes to 224x224, converts to a tensor and normalizes
    each of the three channels with the mean/std values commonly used for
    ImageNet-pretrained torchvision models.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

# %%
# 4. Attention Modules & Model Definition
class ChannelAttention(nn.Module):
    """Channel gate: re-weights each channel of a 4-D feature map.

    Global max- and average-pooled channel descriptors are passed through
    one shared two-layer MLP; the sigmoid of their sum scales the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Divisor applied to ``in_channels`` to obtain the
            width of the MLP's hidden layer.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        hidden = in_channels // reduction_ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention gate."""
        # Pool to one value per channel, then drop the 1x1 spatial dims.
        pooled = (self.max_pool(x), self.avg_pool(x))
        max_out, avg_out = (self.mlp(p.flatten(1)) for p in pooled)
        # Restore two trailing singleton dims so the gate broadcasts over H, W.
        gate = self.sigmoid(max_out + avg_out)[:, :, None, None]
        return x * gate


class SpatialAttention(nn.Module):
    """Spatial gate: re-weights each location of a feature map.

    The channel-wise max and mean maps are stacked into a 2-channel tensor,
    convolved down to one channel and squashed with a sigmoid; the result
    multiplies the input.

    Args:
        kernel_size: Side length of the convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size,
                              padding=half_kernel)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its spatial attention map."""
        channel_max = x.max(dim=1, keepdim=True).values
        channel_avg = x.mean(dim=1, keepdim=True)
        descriptor = torch.cat((channel_max, channel_avg), dim=1)
        return x * self.sigmoid(self.conv(descriptor))


class CBAM(nn.Module):
    """Attention block that applies channel attention, then spatial attention.

    Args:
        in_channels: Channel count of the incoming feature map.
        reduction_ratio: Forwarded to ``ChannelAttention``.
        kernel_size: Forwarded to ``SpatialAttention``.
    """

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_channels, reduction_ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Gate ``x`` by channel first and by spatial location second."""
        return self.sa(self.ca(x))


class CBAMResNet18(nn.Module):
    """ResNet-18 backbone with a CBAM block after each residual stage.

    The backbone is created without pretrained weights. Its own average
    pool and fully connected layer are not used; pooled 512-channel features
    go through a small MLP head instead.

    Args:
        num_classes: Size of the final linear layer's output.
        in_channels: Channels of the input image. When it differs from 3 the
            first convolution is replaced by one with that many input channels.
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        backbone = models.resnet18(weights=None)
        if in_channels != 3:
            backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7,
                                       stride=2, padding=3, bias=False)

        self.stem = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
        )
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # One attention block per stage, sized to that stage's output width.
        self.cbam1 = CBAM(64)
        self.cbam2 = CBAM(128)
        self.cbam3 = CBAM(256)
        self.cbam4 = CBAM(512)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        """Return the classifier head's raw outputs (no softmax) for ``x``."""
        feats = self.stem(x)
        stages = (
            (self.layer1, self.cbam1),
            (self.layer2, self.cbam2),
            (self.layer3, self.cbam3),
            (self.layer4, self.cbam4),
        )
        for stage, attention in stages:
            feats = attention(stage(feats))
        return self.classifier(self.pool(feats))

# %%
# 5. Dataset & DataLoader for Variety Classification
class RiceDataset(Dataset):
    """Image dataset for variety classification with a stratified split.

    The label table must provide the columns ``image_id`` (file name),
    ``label`` (sub-folder of ``image_dir`` holding the file) and ``variety``
    (the classification target). Class indices follow the sorted order of the
    variety names found in the selected split.

    Args:
        image_dir: Root directory containing one sub-folder per ``label``.
        labels_path: Path to the label CSV, or an already-loaded
            ``pandas.DataFrame`` with the same columns.
        split: ``"train"`` selects the training part; any other value
            selects the validation part.
        transform: Optional callable applied to the RGB PIL image.
        val_size: Fraction of rows held out for validation.
        random_seed: Seed for the split, the oversampling and the shuffle.
        oversample: If true (training split only), resample every variety
            with replacement up to the size of the largest one.
    """

    def __init__(self, image_dir, labels_path, split="train",
                 transform=None, val_size=0.2, random_seed=42,
                 oversample=False):
        from sklearn.model_selection import train_test_split

        # Accept a ready DataFrame as well as a CSV location so callers that
        # have already loaded the table do not fail in pd.read_csv.
        if isinstance(labels_path, pd.DataFrame):
            df = labels_path
        else:
            df = pd.read_csv(labels_path)

        train_df, val_df = train_test_split(
            df, test_size=val_size, stratify=df["variety"],
            random_state=random_seed)
        self.metadata = train_df if split == "train" else val_df

        if oversample and split == "train":
            from sklearn.utils import resample

            max_size = self.metadata["variety"].value_counts().max()
            balanced = [
                resample(grp, replace=True, n_samples=max_size,
                         random_state=random_seed)
                for _, grp in self.metadata.groupby("variety")
            ]
            # Shuffle so the rows are no longer grouped by class.
            self.metadata = pd.concat(balanced).sample(
                frac=1, random_state=random_seed)

        self.image_dir = image_dir
        self.transform = transform
        self.classes = sorted(self.metadata["variety"].unique())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        # Column-wise iteration avoids building a Series per row (iterrows).
        self.image_paths = [
            os.path.join(image_dir, folder, img_id)
            for folder, img_id in zip(self.metadata["label"],
                                      self.metadata["image_id"])
        ]
        self.targets = [self.class_to_idx[variety]
                        for variety in self.metadata["variety"]]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a 0-dim long tensor."""
        img = to_pil_image(decode_image(self.image_paths[idx]))
        # Force three channels: grayscale or RGBA files would otherwise break
        # a 3-channel Normalize, and inference also converts to RGB.
        img = img.convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(self.targets[idx], dtype=torch.long)
        return img, label


def get_dataloaders(image_dir, labels_path, batch_size=32,
                    val_size=0.2, oversample=False, num_workers=4):
    """Build the training and validation ``DataLoader`` pair.

    Args:
        image_dir: Root directory of the images (see ``RiceDataset``).
        labels_path: Location of the label CSV, forwarded to ``RiceDataset``.
        batch_size: Batch size used by both loaders.
        val_size: Fraction of the data held out for validation.
        oversample: Whether to balance classes in the training split.
        num_workers: Worker processes per loader; ``0`` loads in the main
            process.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    transform = get_transform()
    train_ds = RiceDataset(image_dir, labels_path, split="train",
                           transform=transform,
                           val_size=val_size, oversample=oversample)
    val_ds = RiceDataset(image_dir, labels_path, split="val",
                         transform=transform,
                         val_size=val_size)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

# %%
# 6. Training Utilities & Trainer
class EarlyStopping:
    """Track the best loss seen and signal when training should stop.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` returns ``True``.
        min_delta: Minimum decrease below the best loss that counts as an
            improvement.

    Attributes:
        best_loss: Lowest loss that counted as an improvement so far.
        best_state: Independent copy of the model's ``state_dict()`` taken
            at the moment ``best_loss`` was recorded, or ``None``.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.best_state = None

    def early_stop(self, loss, model):
        """Record ``loss`` and return ``True`` once patience is exhausted."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.counter = 0
            # state_dict() hands back references to the live parameters, so
            # it must be copied or later optimizer steps overwrite the
            # "best" snapshot in place.
            self.best_state = copy.deepcopy(model.state_dict())
            return False
        self.counter += 1
        return self.counter >= self.patience

# Optional Mixup/CutMix augmenter. Trainer only calls it when it was
# constructed with mixup=True; leave that flag False if it is not needed.
mixup_fn = Mixup(
    num_classes=len(VARIETY_INFO),
    label_smoothing=0.1,
    mode='batch',
    prob=0.5,
    switch_prob=0.5,
    mixup_alpha=0.4,
    cutmix_alpha=1.0,
)

class Trainer:
    """Minimal training loop with early stopping and history tracking.

    Args:
        model: Network to optimize; called as ``model(imgs)``.
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar
            tensor. With ``mixup=True`` it receives the mixed (soft) labels
            produced by ``mixup_fn``, so it must accept them.
        optimizer: Optimizer bound to ``model``'s parameters.
        metric: Stateful metric exposing ``reset()``, ``__call__(preds,
            target)`` and ``compute()``.
        device: Device every batch is moved to.
        model_name: Prefix of the files written when ``save`` is true.
        scheduler: Optional LR scheduler, stepped once per epoch after
            validation. ``ReduceLROnPlateau`` receives the validation loss.
            Schedulers designed to step per batch are not handled.
        save: Whether ``fit`` writes the best weights and the history CSV.
        mixup: Whether to pass training batches through the module-level
            ``mixup_fn``.
            NOTE(review): timm's batch-mode Mixup is understood to require
            an even batch size; confirm the loader never yields an odd
            final batch before enabling this.
    """

    def __init__(self, model, loss_fn, optimizer, metric,
                 device, model_name, scheduler=None, save=True,
                 mixup=False):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.metric = metric
        self.device = device
        self.scheduler = scheduler
        self.save = save
        self.model_name = model_name
        self.mixup = mixup
        self.history = {"train_loss": [], "val_loss": [],
                        "train_f1": [], "val_f1": [], "lr": []}

    def train_epoch(self, loader):
        """Run one optimization pass; return ``(mean loss, metric value)``."""
        self.model.train()
        total_loss = 0
        self.metric.reset()
        for imgs, labels in loader:
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            # Keep the integer class labels: after mixup `labels` holds soft
            # targets, which the metric cannot compare with argmax output.
            hard_labels = labels
            if self.mixup:
                imgs, labels = mixup_fn(imgs, labels)
            preds = self.model(imgs)
            loss = self.loss_fn(preds, labels)
            total_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.metric(preds.argmax(1), hard_labels)
        avg_loss = total_loss / len(loader)
        f1 = self.metric.compute().item()
        lr = self.optimizer.param_groups[0]['lr']
        self.history['train_loss'].append(avg_loss)
        self.history['train_f1'].append(f1)
        self.history['lr'].append(lr)
        return avg_loss, f1

    def val_epoch(self, loader):
        """Evaluate without gradients; return ``(mean loss, metric value)``."""
        self.model.eval()
        total_loss = 0
        self.metric.reset()
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                preds = self.model(imgs)
                loss = self.loss_fn(preds, labels)
                total_loss += loss.item()
                self.metric(preds.argmax(1), labels)
        avg_loss = total_loss / len(loader)
        f1 = self.metric.compute().item()
        self.history['val_loss'].append(avg_loss)
        self.history['val_f1'].append(f1)
        return avg_loss, f1

    def _step_scheduler(self, val_loss):
        """Advance the LR scheduler by one epoch, if one was supplied."""
        if self.scheduler is None:
            return
        plateau_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
        if isinstance(self.scheduler, plateau_cls):
            # This scheduler decides based on the monitored quantity.
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def fit(self, train_loader, val_loader, epochs=10):
        """Train for up to ``epochs`` epochs and return the history dict.

        Stops early when the validation loss stops improving. When ``save``
        is true and a best state was recorded, writes ``<model_name>.pt``
        (best weights) and ``<model_name>_history.csv``.
        """
        stopper = EarlyStopping()
        for epoch in range(epochs):
            self.train_epoch(train_loader)
            val_loss, _ = self.val_epoch(val_loader)
            # The learning rate logged by train_epoch is the one that was
            # in effect for the epoch; the scheduler moves it afterwards.
            self._step_scheduler(val_loss)
            if stopper.early_stop(val_loss, self.model):
                print("Early stopping at epoch", epoch+1)
                break
        if self.save and stopper.best_state:
            torch.save(stopper.best_state, f"{self.model_name}.pt")
            pd.DataFrame(self.history).to_csv(f"{self.model_name}_history.csv", index=False)
        return self.history

# %%
# 7. Example Training & Inference
# Paths (change to your data)
HOME_PATH = os.getcwd() + "/"
image_dir = os.path.join(HOME_PATH, 'train_images')
# get_dataloaders takes the *location* of the label CSV as `labels_path`,
# so keep this a path rather than a loaded DataFrame.
labels_csv = os.path.join(HOME_PATH, "meta_train.csv")

# Get data
train_loader, val_loader = get_dataloaders(image_dir, labels_csv,
                                           batch_size=32, oversample=True)

# Setup model, loss, optimizer, metric
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Size the head from the classes actually present in the data: the dataset
# derives its target indices from the CSV's `variety` column, which is not
# guaranteed to match the entries of VARIETY_INFO.
num_classes = len(train_loader.dataset.classes)
model = CBAMResNet18(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
metric = MulticlassF1Score(num_classes=num_classes).to(device)

# Train
trainer = Trainer(model, criterion, optimizer, metric, device,
                  model_name="variety_model", mixup=False)
history = trainer.fit(train_loader, val_loader, epochs=10)

# %%
# Inference Function

def predict_variety(image_path, model, device, class_names=None):
    """Classify one image file and return ``(variety name, probability)``.

    Args:
        image_path: Path of the image to classify.
        model: Trained classifier that returns one row of class scores.
        device: Device the input tensor is moved to; the model is expected
            to live there already.
        class_names: Names indexed by class id. Defaults to the sorted keys
            of ``VARIETY_INFO``, because ``RiceDataset`` assigns class ids in
            sorted-name order. Pass the training dataset's ``classes`` list
            when the model was trained on a different label set.

    Returns:
        The predicted class name and its softmax probability as a float.
    """
    if class_names is None:
        class_names = sorted(VARIETY_INFO)

    # The context manager closes the file; convert() returns a loaded copy.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    tensor = get_transform()(img).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        probs = F.softmax(model(tensor), dim=1)[0]
    idx = torch.argmax(probs).item()
    return class_names[idx], probs[idx].item()

# Usage
# pred_name, pred_conf = predict_variety("test.jpg", model, device)
# print(f"Predicted: {pred_name} ({pred_conf*100:.2f}%)")


# Notebook output: ModuleNotFoundError: No module named 'torch'
# (PyTorch was not installed in the environment that executed this cell.)