In [None]:
# Fetch the dataset repository into ./data. Later cells read data/train.csv,
# data/val.csv, data/test.csv and the images under data/images/.
# NOTE(review): re-running this cell presumably fails once ./data exists —
# consider guarding it with a directory check.
!git clone https://github.com/manhmitcf/data.git

In [None]:
# Run label; embedded in every output file name (checkpoint, CSV, JSON config, zip).
model_name = "Swin-Tiny"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision import transforms
import os
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
import pandas as pd
from PIL import Image
from torch.optim.lr_scheduler import StepLR

In [None]:


class FishDatasetWithAugmentation(Dataset):
    """Fish-freshness image dataset indexed by a CSV file.

    Rows are read positionally: column 1 holds the class name (a key of
    ``self.labels``) and column 2 holds the image file name relative to
    ``img_dir``.

    Args:
        csv_file: Path to the CSV index.
        img_dir: Directory the CSV file names are resolved against.
        transform: Transform applied to the RGB PIL image. Takes precedence
            over ``aug_transform`` when both are given.
        aug_transform: Transform used only when ``transform`` is falsy.

    Raises:
        FileNotFoundError: If ``img_dir`` does not exist.
        ValueError: If the CSV file has no rows.
    """

    def __init__(self, csv_file, img_dir, transform=None, aug_transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.aug_transform = aug_transform
        # Class name -> integer target.
        self.labels = {"Highly Fresh": 0, "Fresh": 1, "Not Fresh": 2}
        # Validate inputs at construction time.
        if not os.path.exists(img_dir):
            raise FileNotFoundError(f"Th∆∞ m·ª•c ·∫£nh '{img_dir}' kh√¥ng t·ªìn t·∫°i.")
        if self.data.empty:
            raise ValueError(f"File CSV '{csv_file}' kh√¥ng ch·ª©a d·ªØ li·ªáu.")

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def _resolve_image_path(self, idx):
        """Return the existing on-disk path for row ``idx``, trying a '_5' -> '_#' fallback name."""
        file_name = self.data.iloc[idx, 2]
        img_name = os.path.join(self.img_dir, file_name)
        if os.path.exists(img_name):
            return img_name
        # Fallback: the file may be stored with '_#' where the CSV lists '_5'.
        # Rewrite only the CSV-provided part so an img_dir containing '_5'
        # is not altered.
        img_name = os.path.join(self.img_dir, file_name.replace('_5', '_#'))
        if not os.path.exists(img_name):
            raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y ·∫£nh '{img_name}'.")
        return img_name

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(image, label)``.

        ``image`` is the transformed RGB image; ``label`` is a 0-dim
        ``torch.long`` tensor.

        Raises:
            FileNotFoundError: If neither the listed nor the fallback image exists.
            ValueError: If the label is unknown or no transform is configured.
        """
        img_name = self._resolve_image_path(idx)
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory image, so it stays valid after the block.
            with Image.open(img_name) as img:
                image = img.convert('RGB')
        except FileNotFoundError as exc:
            # The file can disappear between the existence check and the open.
            raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y ·∫£nh '{img_name}'.") from exc

        label = self.data.iloc[idx, 1]
        if label not in self.labels:
            raise ValueError(f"Nh√£n '{label}' kh√¥ng h·ª£p l·ªá. Ph·∫£i l√† m·ªôt trong {list(self.labels.keys())}.")
        label = torch.tensor(self.labels[label], dtype=torch.long)

        # `transform` wins over `aug_transform` when both are set.
        if self.transform:
            image = self.transform(image)
        elif self.aug_transform:
            image = self.aug_transform(image)
        else:
            raise ValueError("C·∫£ transform v√† aug_transform ƒë·ªÅu l√† None. √çt nh·∫•t m·ªôt trong hai ph·∫£i ƒë∆∞·ª£c cung c·∫•p.")

        return image, label
# Standard ImageNet per-channel mean / std used by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _make_transform(*augmentations):
    """Build: LANCZOS resize to 224x224 -> optional augmentations -> ToTensor -> Normalize."""
    return transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ])


# Deterministic pipeline (validation / test).
basic_transform = _make_transform()

# Training pipeline: random flip, rotation and brightness jitter between resize and normalize.
aug_transform = _make_transform(
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.2),
)

In [None]:
import timm
from timm import create_model

In [None]:
import torch
import copy

class EarlyStopping:
    """Early-stopping helper that keeps a snapshot of the best epoch's weights.

    The monitored metric is chosen by ``mode``:

    * ``"min"``: an epoch improves when ``val_loss`` is strictly lower than the best seen;
    * ``"max"``: an epoch improves when ``val_acc`` is strictly higher than the best seen.

    On each improvement both ``best_loss`` and ``best_acc`` are set to that
    epoch's values and a deep copy of ``model.state_dict()`` is stored in
    ``best_weights``. After ``patience`` consecutive non-improving epochs,
    ``early_stop`` becomes True.
    """

    def __init__(self, patience=20, mode="min"):
        """
        Args:
            patience: Consecutive epochs without improvement before stopping.
            mode: "min" to monitor val_loss (lower is better),
                "max" to monitor val_acc (higher is better).

        Raises:
            ValueError: If ``mode`` is neither "min" nor "max".
        """
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.mode = mode
        self.best_loss = float("inf")   # worse than any finite loss
        self.best_acc = -float("inf")   # worse than any finite accuracy
        self.counter = 0                # consecutive epochs without improvement
        self.best_weights = None
        self.early_stop = False

    def _improved(self, val_loss, val_acc):
        """Return True if this epoch beats the best so far on the metric selected by ``mode``."""
        if self.mode == "min":
            return val_loss < self.best_loss
        return val_acc > self.best_acc

    def check_improvement(self, val_loss, val_acc, model):
        """Record one epoch's validation results.

        Args:
            val_loss: Validation loss of the epoch.
            val_acc: Validation accuracy of the epoch.
            model: Module whose ``state_dict()`` is snapshotted on improvement.

        Returns:
            True if the monitored metric improved, otherwise False.
        """
        if self._improved(val_loss, val_acc):
            self.best_loss = val_loss
            self.best_acc = val_acc
            # Deep copy so later in-place optimizer updates cannot alter the snapshot.
            self.best_weights = copy.deepcopy(model.state_dict())
            self.counter = 0
            return True

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False

    def restore_best_model(self, model):
        """Load the best snapshot into ``model``; does nothing if no epoch ever improved."""
        if self.best_weights is not None:
            model.load_state_dict(self.best_weights)
            print("Restored best model weights.")


In [None]:
import time

# ---- Configuration ----
TRAIN_CSV_PATH = "data/train.csv"
VAL_CSV_PATH = "data/val.csv"
IMG_DIR = "data/images/"
EPOCHS = 100
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
NUM_CLASSES = 3
SEED = 42
# mode: "min" tracks val loss, "max" tracks val accuracy (per EarlyStopping's docstring).
early_stopper = EarlyStopping(patience=20, mode="min")
epoch = None  # 0-based index of the last epoch run; stays None if EPOCHS == 0

# Seed torch's RNG so shuffling, augmentation and head initialization are repeatable.
torch.manual_seed(SEED)

# ---- Datasets and loaders ----
train_dataset = FishDatasetWithAugmentation(
    csv_file=TRAIN_CSV_PATH,
    img_dir=IMG_DIR,
    transform=None,
    aug_transform=aug_transform,  # random augmentation for the training split only
)

val_dataset = FishDatasetWithAugmentation(
    csv_file=VAL_CSV_PATH,
    img_dir=IMG_DIR,
    transform=basic_transform,
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ---- Model, loss, optimizer ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=NUM_CLASSES)
model = model.to(device)
model = torch.nn.DataParallel(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-2)


def run_epoch(net, loader, loss_fn, run_device, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    Trains (``net.train()``, backward, step) when ``optimizer`` is given;
    otherwise evaluates in ``net.eval()`` mode with gradients disabled.
    The loss is averaged per sample, so a smaller final batch is weighted correctly.
    """
    training = optimizer is not None
    net.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if desc else loader
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = net(images)
            loss = loss_fn(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # loss_fn uses the default reduction='mean', so scale back to a per-batch sum.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    return total_loss / seen, 100 * correct / seen


# ---- Training loop ----
loss_train, loss_val = [], []
acc_train, acc_val = [], []
epoch_times = []
total_start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()

    avg_train_loss, train_accuracy = run_epoch(
        model, train_dataloader, criterion, device,
        optimizer=optimizer, desc=f"Epoch {epoch+1}/{EPOCHS} (Train)",
    )
    avg_val_loss, val_accuracy = run_epoch(model, val_dataloader, criterion, device)

    loss_train.append(avg_train_loss)
    loss_val.append(avg_val_loss)
    acc_train.append(train_accuracy)
    acc_val.append(val_accuracy)

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"    Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
    print(f"    Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

    if early_stopper.check_improvement(avg_val_loss, val_accuracy, model):
        print("Improved! Saving best model.")
    else:
        print(f"No improvement for {early_stopper.counter} epochs.")

    # Record the duration before a possible break so the final epoch is counted too.
    epoch_times.append(time.time() - epoch_start_time)

    if early_stopper.early_stop:
        print("Early stopping triggered.")
        break

# Continue with the best checkpoint whether training stopped early or ran all EPOCHS.
early_stopper.restore_best_model(model)

total_duration = time.time() - total_start_time
print(f"\n‚è≥ Total training time: {total_duration/60:.2f} minutes.")
avg_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0.0
print(f"‚è±Ô∏è Average time per epoch: {avg_epoch_time:.2f} seconds (~{avg_epoch_time/60:.2f} minutes)")

In [None]:
import matplotlib.pyplot as plt
import os

# Make sure the output directory exists before any figure is written.
os.makedirs('result', exist_ok=True)


def _save_history_plot(train_series, val_series, train_label, val_label, ylabel, title, out_path):
    """Draw train-vs-validation curves against epoch index and save them to ``out_path``.

    The figure is closed afterwards, so nothing is rendered inline.
    """
    fig, ax = plt.subplots()
    for series, label in ((train_series, train_label), (val_series, val_label)):
        ax.plot(series, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.savefig(out_path)
    plt.close(fig)


# Loss curves
_save_history_plot(loss_train, loss_val, 'Train Loss', 'Validation Loss',
                   'Loss', 'Loss over Epochs', 'result/loss_curve.png')

# Accuracy curves (stored as percentages)
_save_history_plot(acc_train, acc_val, 'Train Accuracy', 'Validation Accuracy',
                   'Accuracy (%)', 'Accuracy over Epochs', 'result/accuracy_curve.png')

In [None]:

# Save the weights of the underlying network. `model` is wrapped in
# torch.nn.DataParallel during training, and the wrapper's state_dict keys carry
# a "module." prefix; saving the unwrapped module keeps the checkpoint loadable
# into a bare create_model(...) instance.
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
os.makedirs("result", exist_ok=True)
torch.save(model_to_save.state_dict(), f"result/fish_classifier_{model_name}.pth")
print("ƒê√£ l∆∞u m√¥ h√¨nh!")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import timm
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
)
import matplotlib.pyplot as plt

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# ==== Load the test split ====
CSV_PATH = "data/test.csv"
IMG_DIR = "data/images/"

try:
    dataset = FishDatasetWithAugmentation(CSV_PATH, IMG_DIR, transform=basic_transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
except FileNotFoundError as e:
    # Chain the original error so its traceback is preserved.
    raise FileNotFoundError(f"L·ªói khi t·∫£i dataset: {e}") from e

# ==== Predict and compute metrics ====
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in dataloader:
        outputs = model(images.to(device))
        all_preds.append(outputs.argmax(dim=1).cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="macro")
precision = precision_score(all_labels, all_preds, average="macro")
recall = recall_score(all_labels, all_preds, average="macro")

print(f"‚úÖ Accuracy:  {acc:.4f}")
print(f"‚úÖ F1-score:  {f1:.4f}")
print(f"‚úÖ Precision: {precision:.4f}")
print(f"‚úÖ Recall:    {recall:.4f}")

# ==== Classification report + confusion matrix ====
try:
    class_names = ["Highly Fresh", "Fresh", "Not Fresh"]
    # Pin the label set so the report and the matrix always have one row per
    # class name, even when a class is absent from this split's labels and
    # predictions (otherwise target_names / display_labels would mismatch).
    class_ids = list(range(len(class_names)))
    if len(set(all_labels)) > len(class_names):
        raise ValueError("S·ªë l∆∞·ª£ng l·ªõp th·ª±c t·∫ø l·ªõn h∆°n s·ªë l·ªõp ƒë∆∞·ª£c ƒë·ªãnh nghƒ©a.")

    print("\nüìä Classification Report:\n")
    print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

    # Plot and save
    fig, ax = plt.subplots(figsize=(8, 6))
    disp.plot(cmap="Blues", values_format="d", ax=ax)
    ax.set_title("Confusion Matrix")
    os.makedirs('result', exist_ok=True)
    fig.savefig('result/confusion_matrix.png', dpi=300)
    plt.close(fig)  # close so later figures do not draw onto this one

    print("‚úÖ ƒê√£ l∆∞u confusion matrix v√†o th∆∞ m·ª•c 'result/'.")

except ValueError as e:
    print(f"‚ö†Ô∏è L·ªói trong vi·ªác t·∫°o b√°o c√°o l·ªõp: {e}")

In [None]:
# Optimizer name for the run summary, read from the optimizer object actually
# used in training (optim.AdamW) so the logged label cannot drift from the code.
optimize = type(optimizer).__name__

In [None]:
# One-row summary of this run.
results = {
    "model_name": model_name,
    "optimizer": optimize,
    "lr": LEARNING_RATE,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
    # Number of epochs actually run (one loss_train entry per epoch). The loop
    # variable `epoch` is 0-based and would under-count by one.
    "epochs_current": len(loss_train),
    # Validation accuracy is tracked in percent during training; scale to a 0-1
    # fraction to match the test `accuracy` below.
    "val_best_acc": early_stopper.best_acc * 0.01,
    "val_best_loss": early_stopper.best_loss,
    "accuracy": acc,
    "precision": precision,
    "recall": recall,
    "f1_score": f1,
    "time(s)": total_duration,
    "time_per_epoch(s)": avg_epoch_time
}

In [None]:
# Persist the one-row run summary as CSV (without the index column).
summary_csv_path = f"result/evaluation_results_{model_name}.csv"
results_df = pd.DataFrame([results])
results_df.to_csv(summary_csv_path, index=False)

In [None]:
import json
import os
from datetime import datetime

def _to_json_scalar(obj):
    """json fallback: unwrap NumPy / PyTorch scalars through their ``.item()`` method."""
    item = getattr(obj, "item", None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_model_config(save_dir, config, name=None):
    """Write ``config`` as indented JSON to ``<save_dir>/<name>_config.json``.

    Entries whose value is ``None`` are dropped to keep the file compact; the
    caller's dict is not modified.

    Args:
        save_dir: Output directory; created if missing.
        config: Mapping of JSON-serializable values. Objects exposing
            ``.item()`` (NumPy / PyTorch scalars) are converted to Python numbers.
        name: File-name prefix. Defaults to ``config["model_name"]`` and, if
            that is absent, to the notebook-level ``model_name``.

    Returns:
        The path of the written file.

    Raises:
        TypeError: If a value cannot be serialized. Nothing is written in that case.
    """
    config = {k: v for k, v in config.items() if v is not None}

    if name is None:
        # Prefer the name carried by the config itself over the notebook global.
        name = config.get("model_name")
        if name is None:
            name = model_name

    # Serialize first so a bad value cannot leave a truncated file behind.
    text = json.dumps(config, indent=4, default=_to_json_scalar)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{name}_config.json")

    with open(save_path, 'w') as f:
        f.write(text)

    print(f"‚úÖ ƒê√£ l∆∞u file config: {save_path}")
    return save_path


In [None]:
# Run metadata saved as JSON. Hyperparameters are read from the objects used in
# training where possible, so this record cannot drift from the code.
config = {
    "model_name": model_name,
    "input_size": 224,
    "num_classes": NUM_CLASSES,
    "batch_size": BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "optimizer": optimize,
    "loss_function": type(criterion).__name__,
    # Taken from the optimizer itself (training builds AdamW with weight_decay=1e-2).
    "weight_decay": optimizer.param_groups[0]["weight_decay"],
    # Epochs actually run; the loop variable `epoch` is a 0-based index.
    "epoch_trained": len(loss_train),
    "best_val_acc": early_stopper.best_acc * 0.01,  # percent -> fraction
    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "loss_train": loss_train,
    "loss_val": loss_val,
    "acc_train": acc_train,
    "acc_val": acc_val
}

In [None]:
# Write the run config JSON into result/ (trailing ';' suppresses any cell output).
save_model_config(config=config, save_dir="result");

In [None]:
from IPython.display import FileLink, display
import shutil
import os

# Bundle the contents of result/ into ./result_<model_name>.zip. The archive
# name is built from its base instead of stripping ".zip" back off with
# str.replace, which would also remove ".zip" appearing inside model_name.
archive_base = f"result_{model_name}"
zip_filename = f"{archive_base}.zip"

shutil.make_archive(base_name=archive_base, format='zip', root_dir='result')

# Show a download link for the archive.
print("‚úÖ File ƒë√£ s·∫µn s√†ng ƒë·ªÉ t·∫£i:")
display(FileLink(zip_filename))
