In [1]:
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import ImageDataset
from distortions import *
from models import DistortionBinaryClassifier, IQAEncoder
from torchvision import transforms
from torch.utils.data import DataLoader


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(model_path, device):
    model = IQAEncoder(feature_dim=128, model_name='resnet50').to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

encoder = load_model('../models/resnet50_128_out.pth', device)
model = DistortionBinaryClassifier(encoder).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

distortions = [Clean(), LensBlur(), MotionBlur(), GaussianNoise(), Overexposure(), Underexposure(), Compression(), Ghosting(), Aliasing()]
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)), # 224 or 384
    transforms.ToTensor(),
])

train_image_folders = [f"../data/video_frames_{i}" for i in range(1, 13) if i != 5]
image_paths = []
for folder in train_image_folders:
    if not os.path.isdir(folder):
        print(f"Warning: folder {folder} not found, skipping.")
        continue
    image_paths.extend(
        [os.path.join(folder, fname)
         for fname in os.listdir(folder)
         if fname.lower().endswith(('.jpg', '.png'))]
    )
dataset = ImageDataset(image_paths, distortions=distortions, transform=transform, binary_labels=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
print(f"Train Dataset length: {len(dataset)}")

eval_image_folders = "../data/video_frames_5" 
image_paths = [os.path.join(eval_image_folders, fname) for fname in os.listdir(eval_image_folders) if fname.endswith(('.jpg', '.png'))]
eval_dataset = ImageDataset(image_paths, distortions=distortions, transform=transform, binary_labels=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=64, shuffle=False, num_workers=4)
print(f"Eval Dataset length: {len(eval_dataset)}")

epochs = 100
best_loss = float('inf')
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for imgs, labels in tqdm(dataloader, desc=f"Binary Epoch {epoch+1}", leave=False):
        imgs = imgs.to(device)
        binary_labels = torch.tensor([0.0 if l == 'Clean' else 1.0 for l in labels], dtype=torch.float32, device=device).unsqueeze(1)  # shape: [B, 1]

        logits = model(imgs)
        loss = loss_fn(logits, binary_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {epoch_loss / len(dataloader):.4f}")

    # Validation
    if ((epoch+1) % 1 == 0) or (epoch == 0):
        accuracy = 0.0
        eval_loss = 0.0
        model.eval()
        with torch.no_grad():
            for imgs, labels in tqdm(eval_dataloader, desc=f"Eval Epoch {epoch+1}", leave=False):
                imgs = imgs.to(device)
                binary_labels = torch.tensor([0.0 if l == 'Clean' else 1.0 for l in labels], dtype=torch.float32, device=device).unsqueeze(1)
                logits = model(imgs)
                preds = torch.sigmoid(logits) > 0.5
                accuracy += (preds == binary_labels).float().mean().item()
                eval_loss = loss_fn(logits, binary_labels).item()
        
        eval_loss /= len(eval_dataloader)
        accuracy /= len(eval_dataloader)
        print(f"Eval Accuracy: {accuracy:.4f}, Eval Loss: {eval_loss:.4f}")
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), f'../models/best_binary_classifier.pth')
            print(f"Saved Best Model with Loss: {best_loss / len(dataloader):.4f}")
        else:
            print(f"Early stopping at epoch {epoch + 1} with loss {eval_loss:.4f} (best: {best_loss:.4f})")
            break

        


Train Dataset length: 76625
Eval Dataset length: 4045


                                                                   

Epoch 1, Loss: 0.3719


                                                             

Eval Accuracy: 0.8563, Eval Loss: 0.0201
Saved Best Model with Loss: 0.3719


                                                                   

Epoch 2, Loss: 0.3241


                                                             

Eval Accuracy: 0.8277, Eval Loss: 0.0083
Saved Best Model with Loss: 0.3241


                                                                   

Epoch 3, Loss: 0.3099


                                                             

Eval Accuracy: 0.8629, Eval Loss: 0.0098
Saved Best Model with Loss: 0.3099


                                                                   

Epoch 4, Loss: 0.2978


                                                             

Eval Accuracy: 0.8964, Eval Loss: 0.0165
Saved Best Model with Loss: 0.2978


                                                                   

Epoch 5, Loss: 0.2895


                                                             

Eval Accuracy: 0.8788, Eval Loss: 0.0202
Saved Best Model with Loss: 0.2895


                                                                   

Epoch 6, Loss: 0.2802


                                                             

Eval Accuracy: 0.8710, Eval Loss: 0.0258
Saved Best Model with Loss: 0.2802


                                                                   

Epoch 7, Loss: 0.2803


                                                             

Eval Accuracy: 0.8885, Eval Loss: 0.0176
Early stopping at epoch 7 with loss 0.0176 (best: 335.7181)


