In [1]:
# ============================================================
# STEP 0: Imports
# ============================================================
import os
import shutil
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18
from tqdm import tqdm

In [2]:
# ============================================================
# STEP 1: Merge all RAW images into ONE folder (no subfolders)
# ============================================================
import os
import shutil

dataset_path = "/kaggle/input/fish-classification-dataset/Fish Data"
merged_path = "/kaggle/working/fish_merged_raw"
os.makedirs(merged_path, exist_ok=True)

for cls in os.listdir(dataset_path):
    cls_path = os.path.join(dataset_path, cls)
    if not os.path.isdir(cls_path):
        continue
    
    raw_folder = None
    for name in os.listdir(cls_path):
        if name.lower() == "raw" or name.lower() == "raw data":
            raw_folder = os.path.join(cls_path, name)
            break
    
    if raw_folder is None:
        print(f"⚠ No raw data found for class {cls}")
        continue
    
    for img in os.listdir(raw_folder):
        src = os.path.join(raw_folder, img)
        dst = os.path.join(merged_path, f"{cls}_{img}")  # prefix with class name to avoid name clashes
        shutil.copy(src, dst)

print("✅ Merging complete. All raw images saved at:", merged_path)
print("Total images:", len(os.listdir(merged_path)))


✅ Merging complete. All raw images saved at: /kaggle/working/fish_merged_raw
Total images: 26950


In [3]:
# ============================================================
# STEP 2: BYOL Transform
# ============================================================
class BYOLTransform:
    def __init__(self):
        self.base_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
    
    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)



In [4]:
from PIL import Image
from torch.utils.data import Dataset
import glob

class FlatImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.files = glob.glob(os.path.join(root_dir, "*"))
        self.transform = transform
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx]
        img = Image.open(img_path).convert("RGB")
        
        if self.transform:
            img1, img2 = self.transform(img)
            return img1, img2
        
        return img

# Now replace train_dataset and train_loader:
train_dataset = FlatImageDataset(root_dir=merged_path, transform=BYOLTransform())
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

print("Total images loaded:", len(train_dataset))


Total images loaded: 26950


In [5]:
# ============================================================
# STEP 4: BYOL Model Definition
# ============================================================
class BYOL(nn.Module):
    def __init__(self, encoder, projection_dim=256):
        super().__init__()
        self.online_encoder = encoder
        self.online_projector = nn.Sequential(
            nn.Linear(512, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim)
        )
        
        self.target_encoder = resnet18(pretrained=False)
        self.target_encoder.fc = nn.Identity()
        self.target_projector = nn.Sequential(
            nn.Linear(512, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim)
        )
        
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False
        
        self.criterion = nn.MSELoss()
    
    def forward(self, x1, x2):
        o1 = self.online_projector(self.online_encoder(x1))
        o2 = self.online_projector(self.online_encoder(x2))
        
        with torch.no_grad():
            t1 = self.target_projector(self.target_encoder(x1))
            t2 = self.target_projector(self.target_encoder(x2))
        
        loss = self.criterion(o1, t2) + self.criterion(o2, t1)
        return loss



In [6]:
# ============================================================
# STEP 5: Training Setup
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = resnet18(pretrained=False)
encoder.fc = nn.Identity()

model = BYOL(encoder).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()

EPOCHS = 30



  scaler = torch.cuda.amp.GradScaler()


In [None]:
model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{EPOCHS}]")
    for img1, img2 in loop:  # <-- no label unpacking here
        img1, img2 = img1.to(device), img2.to(device)
        
        with torch.cuda.amp.autocast():
            loss = model(img1, img2)
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")


  with torch.cuda.amp.autocast():
Epoch [1/30]: 100%|██████████| 211/211 [04:07<00:00,  1.17s/it, loss=0.23] 


Epoch 1 Loss: 0.2504


Epoch [2/30]: 100%|██████████| 211/211 [04:06<00:00,  1.17s/it, loss=0.233]


Epoch 2 Loss: 0.2342


Epoch [3/30]: 100%|██████████| 211/211 [04:07<00:00,  1.17s/it, loss=0.255]


Epoch 3 Loss: 0.2329


Epoch [4/30]: 100%|██████████| 211/211 [04:03<00:00,  1.15s/it, loss=0.235]


Epoch 4 Loss: 0.2319


Epoch [5/30]: 100%|██████████| 211/211 [04:00<00:00,  1.14s/it, loss=0.225]


Epoch 5 Loss: 0.2306


Epoch [6/30]: 100%|██████████| 211/211 [03:58<00:00,  1.13s/it, loss=0.238]


Epoch 6 Loss: 0.2297


Epoch [7/30]: 100%|██████████| 211/211 [03:58<00:00,  1.13s/it, loss=0.224]


Epoch 7 Loss: 0.2299


Epoch [8/30]:  14%|█▍        | 30/211 [00:37<03:23,  1.13s/it, loss=0.209]

In [None]:
# ============================================================
# STEP 7: Save the Encoder
# ============================================================
torch.save(model.online_encoder.state_dict(), "/kaggle/working/byol_encoder.pth")
print("✅ Encoder saved to /kaggle/working/byol_encoder.pth")

In [None]:
import os
import random
import shutil

original_path = "/kaggle/input/fish-classification-dataset/Fish Data"
split_root = "/kaggle/working/fish_split"

def create_train_val_test_split(src_folder, dest_folder, train_ratio=0.7, val_ratio=0.15):
    os.makedirs(dest_folder, exist_ok=True)
    splits = ['train', 'val', 'test']
    for s in splits:
        os.makedirs(os.path.join(dest_folder, s), exist_ok=True)

    for cls in os.listdir(src_folder):
        cls_path = os.path.join(src_folder, cls)
        if not os.path.isdir(cls_path):
            continue
        
        # Find raw or raw data folder
        raw_folder = None
        for name in os.listdir(cls_path):
            if name.lower() in ("raw", "raw data"):
                raw_folder = os.path.join(cls_path, name)
                break
        if raw_folder is None:
            continue

        imgs = [f for f in os.listdir(raw_folder) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        random.shuffle(imgs)

        n = len(imgs)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)

        splits_data = {
            'train': imgs[:n_train],
            'val': imgs[n_train:n_train + n_val],
            'test': imgs[n_train + n_val:]
        }

        for split_name, img_list in splits_data.items():
            dest_cls_dir = os.path.join(dest_folder, split_name, cls)
            os.makedirs(dest_cls_dir, exist_ok=True)
            for img_name in img_list:
                src_img_path = os.path.join(raw_folder, img_name)
                dst_img_path = os.path.join(dest_cls_dir, img_name)
                shutil.copy(src_img_path, dst_img_path)

    print(f"Train/val/test split created at {dest_folder}")

create_train_val_test_split(original_path, split_root)


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_ds = datasets.ImageFolder(root=os.path.join(split_root, "train"), transform=eval_transform)
val_ds = datasets.ImageFolder(root=os.path.join(split_root, "val"), transform=eval_transform)
test_ds = datasets.ImageFolder(root=os.path.join(split_root, "test"), transform=eval_transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=4)


In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = resnet18(pretrained=False)
encoder.fc = nn.Identity()
encoder.load_state_dict(torch.load("/kaggle/working/byol_encoder.pth"))
encoder = encoder.to(device)
encoder.eval()

for param in encoder.parameters():
    param.requires_grad = False


In [None]:
num_classes = len(train_ds.classes)
linear_clf = nn.Linear(512, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_clf.parameters(), lr=1e-3)

epochs = 30

for epoch in range(epochs):
    linear_clf.train()
    total_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.no_grad():
            feats = encoder(imgs)

        outputs = linear_clf(feats)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # Validation accuracy
    linear_clf.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            feats = encoder(imgs)
            outputs = linear_clf(feats)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f"Validation Accuracy: {correct/total:.4f}")


In [None]:
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

linear_clf.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        feats = encoder(imgs)
        outputs = linear_clf(feats)
        preds = outputs.argmax(dim=1)
        y_true.extend(labels.cpu().tolist())
        y_pred.extend(preds.cpu().tolist())

test_acc = accuracy_score(y_true, y_pred)
print(f"Test Accuracy: {test_acc:.4f}")

cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=train_ds.classes)
fig, ax = plt.subplots(figsize=(10,10))
disp.plot(ax=ax, xticks_rotation='vertical')
plt.title("Confusion Matrix on Test Set")
plt.show()
