In [6]:
from models.resnet_simclr import ResNetSimCLR
import torch
from torchvision import transforms
from logger import get_logger
from config import *
from data_aug.aug_loader import AugmentedImageDataset
from data_aug.image_dataset import ImageDataset
from torch.utils.data import DataLoader
from train_pretext import train_pretext
from train_fine_tune import train_fine_tune

# Notebook-wide logger; main() reports whether a pre-trained checkpoint was loaded through it.
logger = get_logger()

In [7]:
def get_aug_loader(crop_size=64):
    """Build the shuffled loader of augmented, unlabeled images for the pretext task.

    Images are loaded by ``ImageDataset`` with a deterministic base transform
    (resize to ``IMG_SIZE``, tensor conversion, mean/std normalization) and then
    wrapped in ``AugmentedImageDataset`` together with the random augmentation
    pipeline built below.

    Args:
        crop_size: Output size passed to ``RandomResizedCrop``. Defaults to 64,
            the value that was previously hard-coded here.

    Returns:
        ``(loader, classes)``: a ``DataLoader`` over the augmented dataset with
        ``batch_size=BATCH_SIZE_PRETEXT`` and ``shuffle=True``, and the
        ``classes`` attribute of the unlabeled ``ImageDataset``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
    # Exact inverse of `normalize`: (x - (-m/s)) / (1/s) == x * s + m, which maps a
    # normalized tensor back to the [0, 1] range it had right after ToTensor().
    denormalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(channel_mean, channel_std)],
        std=[1.0 / s for s in channel_std],
    )

    resizer = transforms.Compose([
        transforms.Resize(IMG_SIZE),  # Resize Image
        transforms.ToTensor(),  # Convert Image to Tensor
        normalize,  # Normalization
    ])

    # The photometric ops (ColorJitter in particular) treat float tensors as
    # [0, 1] images and clamp their output to that range, so applying them to
    # normalized input destroyed every value outside [0, 1]. Undo the
    # normalization first, augment in [0, 1] space, then re-apply the same
    # normalization used by the base transform.
    # NOTE(review): assumes AugmentedImageDataset feeds the (normalized) tensor
    # produced by `resizer` into aug_transform — confirm against data_aug.aug_loader.
    aug_transform = transforms.Compose([
        denormalize,
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
        transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
        normalize,
    ])

    unlabeled_dataset = ImageDataset(root=DATA_PATH, force_download=False, unlabeled=True, transform=resizer)
    augmented_dataset = AugmentedImageDataset(unlabeled_dataset, transform=aug_transform)

    loader = DataLoader(augmented_dataset, batch_size=BATCH_SIZE_PRETEXT, shuffle=True)
    return loader, unlabeled_dataset.classes

In [8]:
def get_train_test_loader():
    """Create the labeled train/validation loaders used for fine-tuning.

    Both splits share a single deterministic preprocessing pipeline: resize to
    ``IMG_SIZE``, convert to a tensor, and normalize with fixed per-channel
    mean/std (the commonly used ImageNet statistics). No augmentation is applied.

    Returns:
        ``(train_loader, test_loader)``: the ``train=True`` split, shuffled, and
        the ``valid=True`` split in fixed order, both batched by
        ``BATCH_SIZE_FINE_TUNE``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

    # Arguments shared by both splits; only the split selector differs.
    split_kwargs = dict(root=DATA_PATH, force_download=False, transform=preprocess)
    train_split = ImageDataset(train=True, **split_kwargs)
    valid_split = ImageDataset(valid=True, **split_kwargs)

    train_loader = DataLoader(train_split, batch_size=BATCH_SIZE_FINE_TUNE, shuffle=True)
    test_loader = DataLoader(valid_split, batch_size=BATCH_SIZE_FINE_TUNE, shuffle=False)
    return train_loader, test_loader

In [9]:
def main(device):
    """Run SimCLR pretext training on unlabeled data, then supervised fine-tuning.

    Pretext stage: trains ``model.backbone`` on the augmented unlabeled loader,
    first trying to resume backbone weights from
    ``models/snapshot/resnet_simclr_epoch_{BASE_EPOCH}.pth``. The full model's
    state dict is then saved to ``models/snapshot/resnet_simclr_final.pth``.

    Fine-tune stage: trains the whole model on the labeled train split,
    evaluating on the validation split, and saves the result to
    ``models/snapshot/complete_resnet_simclr_final.pth``.

    Args:
        device: ``torch.device`` the model and checkpoints are placed on.
    """
    import os

    unlabeled_loader, data_classes = get_aug_loader()
    model = ResNetSimCLR(out_dim=len(data_classes))
    model.to(device)

    base_epoch = 0
    try:
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
        # machine (or a different GPU index).
        backbone_state = torch.load(f"models/snapshot/resnet_simclr_epoch_{BASE_EPOCH}.pth",
                                    map_location=device)
    except FileNotFoundError:
        logger.info("No pre-trained model found, starting from scratch.")
    else:
        model.backbone.load_state_dict(backbone_state)
        logger.info(f"Pre-trained model epoch{BASE_EPOCH} loaded successfully.")
        base_epoch = BASE_EPOCH
        # NOTE(review): only backbone weights are restored; optimizer and scheduler
        # state below start fresh when resuming.

    # Make sure the snapshot directory exists before anything is written to it.
    os.makedirs("models/snapshot", exist_ok=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # NOTE(review): T_max is the number of batches per epoch; whether this spans the
    # intended cosine schedule depends on how often train_pretext calls
    # scheduler.step() (per batch vs. per epoch) — confirm.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(unlabeled_loader),
                                                           eta_min=0, last_epoch=-1)
    train_pretext(model=model.backbone, loader=unlabeled_loader, optimizer=optimizer, scheduler=scheduler,
                  device=device, base_epoch=base_epoch, total_epochs=EPOCHS_PRETEXT)

    torch.save(model.state_dict(), "models/snapshot/resnet_simclr_final.pth")

    train_loader, test_loader = get_train_test_loader()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE_FINE_TUNE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader),
                                                           eta_min=0, last_epoch=-1)
    train_fine_tune(model=model, train_loader=train_loader, test_loader=test_loader, optimizer=optimizer,
                    device=device, total_epochs=EPOCHS_FINE_TUNE, scheduler=scheduler)
    torch.save(model.state_dict(), "models/snapshot/complete_resnet_simclr_final.pth")

In [10]:
# Seed torch's RNG (on all devices) so augmentation, shuffling and weight
# initialization are repeatable across Restart & Run All.
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)

# Index of the CUDA device to train on; falls back to CPU when CUDA is unavailable.
DEVICE_NUM = 0
device = torch.device(f"cuda:{DEVICE_NUM}" if torch.cuda.is_available() else "cpu")
print(device)
main(device)

cuda:0
INFO: Dataset archive found in the root directory. Skipping download.




[INFO] No pre-trained model found, starting from scratch.
Training Pretext Task - Epoch 1/150


Epoch 1/150:   6%|▌         | 4/71 [00:14<03:56,  3.53s/it]


KeyboardInterrupt: 