<a href="https://colab.research.google.com/github/joshualin24/Lens_Finder/blob/master/Vision_Transformers_for_Strong_Lensing_Parameters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-specific: change into the project directory on the mounted Google Drive
# so that the relative "./datasets/..." paths used in later cells resolve.
%cd /content/drive/MyDrive/Deep_Cosmos_AI/DeepLense/Strong-Lensing-ViT

/content/drive/.shortcut-targets-by-id/1SOTxsao-uEVVYslV76SRXz66Gw3IntqV/Deep_Cosmos_AI/DeepLense/Strong-Lensing-ViT


In [None]:
import os
import glob
import time
import copy
import numpy as np
import pandas as pd
from skimage.transform import resize

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader

In [None]:
class CustomDataset(Dataset):
    """Dataset of ``(image, parameters)`` pairs loaded from pickled ``.npy`` files.

    Every element of every file is appended to ``self.data_samples`` and is
    unpacked as an ``(image, parameters)`` pair in ``__getitem__``.

    Args:
        data_paths: Iterable of paths to ``.npy`` files.
        use_pretrain: If True, each image is resized to 224x224 before being
            converted (presumably to match ImageNet-style backbone input
            size -- TODO confirm intent).
    """

    def __init__(self, data_paths, use_pretrain=True):
        super(CustomDataset, self).__init__()
        self.use_pretrain = use_pretrain

        self.data_samples = []
        for path in data_paths:
            print(f"Loading `{path}`...")
            # NOTE(security): allow_pickle=True can execute code embedded in the
            # file; only load trusted data.
            self.data_samples.extend(np.load(path, allow_pickle=True))

    def __getitem__(self, index):
        """Return sample ``index`` as float32 tensors.

        Returns:
            image: float32 tensor of shape ``(3, *image.shape)``; the stored
                image replicated along a new leading channel axis (assumes the
                stored image is single-channel 2-D -- TODO confirm).
            parameters: float32 tensor built from the stored parameter array.
        """
        # Bug fix: this previously read data_samples[0] for every index, so each
        # batch contained the same example and the training loss collapsed to 0.
        image, parameters = self.data_samples[index]

        if self.use_pretrain:
            image = resize(image, (224, 224))
        image = torch.from_numpy(np.asarray(image)).type(torch.float32)
        image = torch.stack([image, image, image], dim=0)
        parameters = torch.from_numpy(np.asarray(parameters)).type(torch.float32)

        return image, parameters

    def __len__(self):
        """Number of loaded samples across all files."""
        return len(self.data_samples)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device, log_every):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss (NaN if empty)."""
    model.train()  # Set model to training mode
    running_loss = 0.0
    num_images = 0
    # enable_grad keeps training working even if the caller is inside no_grad().
    with torch.enable_grad():
        for batch_idx, (inputs, labels) in enumerate(loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            num_images += batch_size
            # Assumes `criterion` averages over the batch; re-weighting by batch
            # size makes the epoch mean exact when the last batch is smaller.
            running_loss += loss.item() * batch_size
            if log_every and batch_idx % log_every == 0:
                print("Batch {}/{} -> Loss: {:.8f}".format(batch_idx, len(loader), loss.item()))
    return running_loss / num_images if num_images else float("nan")


def _evaluate(model, loader, criterion, device):
    """Return the mean per-sample loss of ``model`` over ``loader`` without gradients (NaN if empty)."""
    model.eval()  # Set model to evaluate mode
    running_loss = 0.0
    num_images = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            loss = criterion(model(inputs), labels)

            batch_size = inputs.size(0)
            num_images += batch_size
            running_loss += loss.item() * batch_size
    return running_loss / num_images if num_images else float("nan")


def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler,
                num_epochs=25, device=None, log_every=10, restore_best=False):
    """Train ``model`` for ``num_epochs`` epochs, reporting train and validation loss.

    Each epoch runs one training pass, steps ``scheduler`` once, and then
    evaluates on ``valid_loader``.

    Args:
        model: The ``nn.Module`` to train (updated in place).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        valid_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function ``criterion(outputs, labels) -> scalar tensor``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        device: Device to move batches to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the function no longer depends on a notebook-global ``device``.
        log_every: Print the batch loss every ``log_every`` batches; 0 disables.
        restore_best: If True, load the weights from the epoch with the lowest
            validation loss before returning. Otherwise the final weights are kept.

    Returns:
        The same ``model`` object.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    since = time.time()
    best_val_loss = float("inf")
    best_model_wts = None

    for epoch in range(1, num_epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, log_every)
        scheduler.step()
        val_loss = _evaluate(model, valid_loader, criterion, device)

        print("Epoch [{:2d}/{:2d}] Train Loss: {:.8f}| Val Loss: {:.8f}".format(
            epoch, num_epochs, train_loss, val_loss))

        # NaN never compares less-than, so empty validation sets never become "best".
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            if restore_best:
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:.8f}'.format(best_val_loss))

    if restore_best and best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    return model

In [None]:
# ===== Training configuration =====
# The regression head predicts one value per lens-model parameter. An earlier
# metadata-based loader listed these 10 names; presumably they are the targets
# stored in the .npy parameter vectors, in this order -- TODO confirm:
#   theta_E, gamma, center_x, center_y, e1, e2, gamma_ext, psi_ext,
#   lens_light_n_sersic, lens_light_R_sersic
SEED = 42                 # seeds weight init and DataLoader shuffling below
batch_size = 64
num_outputs_params = 10   # must match the length of each target parameter vector
learing_rate = 1e-4       # (sic) name kept because later cells reference it
step_size = 10            # StepLR: decay the learning rate every `step_size` epochs
num_epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training is stochastic (random init, shuffle=True); seed for reproducible runs.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# glob.glob returns files in arbitrary, filesystem-dependent order; sort so the
# shard slice below selects the same files on every machine. (Sorting is
# lexicographic: "train_meta-10" sorts before "train_meta-2".)
NUM_TRAIN_SHARDS = 4  # use only the first few training shards
train_paths = sorted(glob.glob("./datasets/train/*.npy"))[:NUM_TRAIN_SHARDS]
valid_paths = sorted(glob.glob("./datasets/validation/*.npy"))
train_paths, valid_paths

(['./datasets/train/train_meta-0.npy',
  './datasets/train/train_meta-1.npy',
  './datasets/train/train_meta-2.npy',
  './datasets/train/train_meta-3.npy'],
 ['./datasets/validation/valid_meta.npy'])

In [None]:
# Load every .npy shard into memory; each constructor prints one line per file.
train_dataset, valid_dataset = (
    CustomDataset(train_paths),
    CustomDataset(valid_paths),
)



Loading `./datasets/train/train_meta-0.npy`...
Loading `./datasets/train/train_meta-1.npy`...
Loading `./datasets/train/train_meta-2.npy`...
Loading `./datasets/train/train_meta-3.npy`...
Loading `./datasets/validation/valid_meta.npy`...


In [None]:
# Shuffle only the training set; validation order does not affect the mean loss.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Bug fix: validation previously iterated over `train_dataset`, so the reported
# "Val Loss" measured training-set fit and could not reveal overfitting.
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# ResNet-18 with randomly initialised weights. Its final fully-connected layer
# is swapped for a linear head with one output per lens parameter.
# NOTE(review): the dataset resizes to 224x224 via `use_pretrain=True`, yet no
# pretrained weights are loaded here -- confirm this is intended.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=num_outputs_params)
model = model.to(device)

# Regression loss plus Adam, with the learning rate multiplied by 0.1 every
# `step_size` epochs.
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(model.parameters(), lr=learing_rate)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=step_size,
    gamma=0.1,
)

In [None]:
# Train for `num_epochs` epochs; returns the (in-place updated) model.
trained_model = train_model(
    model, train_loader, valid_loader,
    criterion, optimizer, scheduler,
    num_epochs=num_epochs,
)

Batch 0/188 -> Loss: 5.04981947
Batch 10/188 -> Loss: 0.02691992
Batch 20/188 -> Loss: 0.00591745
Batch 30/188 -> Loss: 0.00019096
Batch 40/188 -> Loss: 0.00002119
Batch 50/188 -> Loss: 0.00002962
Batch 60/188 -> Loss: 0.00001738
Batch 70/188 -> Loss: 0.00000057
Batch 80/188 -> Loss: 0.00000180
Batch 90/188 -> Loss: 0.00000037
Batch 100/188 -> Loss: 0.00000018
Batch 110/188 -> Loss: 0.00000006
Batch 120/188 -> Loss: 0.00000003
Batch 130/188 -> Loss: 0.00000000
Batch 140/188 -> Loss: 0.00000000
Batch 150/188 -> Loss: 0.00000000
Batch 160/188 -> Loss: 0.00000000
Batch 170/188 -> Loss: 0.00000000
Batch 180/188 -> Loss: 0.00000000
Epoch [ 1/20] Train Loss: 0.09916129| Val Loss: 0.00000382
Batch 0/188 -> Loss: 0.00000000
Batch 10/188 -> Loss: 0.00000000
Batch 20/188 -> Loss: 0.00000000
Batch 30/188 -> Loss: 0.00000000
Batch 40/188 -> Loss: 0.00000000
Batch 50/188 -> Loss: 0.00000000
Batch 60/188 -> Loss: 0.00000000
Batch 70/188 -> Loss: 0.00000000
Batch 80/188 -> Loss: 0.00000000
Batch 90/1

KeyboardInterrupt: ignored