# **Imports and required installation**

In [None]:
# Notebook shell escape: install the two third-party packages imported below
# (pytorch_lightning and wandb) into the running kernel's environment.
!pip install pytorch-lightning wandb

In [None]:
import os
import zipfile
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchvision.models as models
from torchvision import transforms

import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import wandb



In [None]:
# Authenticate this session with Weights & Biases before any sweep/run is created.
wandb.login()


In [None]:
# Notebook-wide settings shared by the data pipeline, the model and the sweep.
root_dir_for_dataset = '/tmp/dino'  # folder the dataset is read from
number_of_classes = 5               # size of the classifier output
input_size = (400, 400)             # size passed to transforms.Resize

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# **Dataset preparation**

In [None]:
# Unpack the dataset archive into /tmp/.
# NOTE(review): the rest of the notebook reads from '/tmp/dino', so the archive
# presumably contains a top-level 'dino' folder -- confirm.
image_zip = '/content/drive/MyDrive/dino.zip'
# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(image_zip, 'r') as zip_ref:
    zip_ref.extractall('/tmp/')

In [None]:
class My_Dataset(Dataset):
    """Image dataset laid out as one sub-directory of ``root_dir`` per class.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        transforms: Optional callable applied to each loaded PIL image.

    Attributes:
        classes: Class-folder names in sorted order.
        class_to_idx: Mapping from class name to integer label.
        image_files: Paths of all samples, grouped by class.
        labels: Integer label for each entry of ``image_files``.
    """

    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order, so sort to make the
        # label indices reproducible across machines and runs. Only real
        # directories count as classes; stray files in the root must not
        # consume a label index.
        self.classes = sorted(
            entry
            for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.image_files, self.labels = self.load_image_files()
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and, when transforms were supplied,
        passed through them.
        """
        image_path = self.image_files[idx]
        label = self.labels[idx]

        # Close the underlying file as soon as the pixel data has been
        # converted instead of leaving the handle open.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transforms is not None:
            image = self.transforms(image)

        return image, label

    def load_image_files(self):
        """Collect every file path under each class folder with its label.

        Returns:
            A ``(image_files, labels)`` pair of equally long lists.
        """
        image_files = []
        labels = []
        for class_name in self.classes:
            class_dir = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            class_label = self.class_to_idx[class_name]
            # Sorted for the same reproducibility reason as the class list.
            for image_name in sorted(os.listdir(class_dir)):
                image_files.append(os.path.join(class_dir, image_name))
                labels.append(class_label)
        return image_files, labels



In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module that splits ``My_Dataset`` into train/val/test.

    Args:
        root_dir: Dataset root folder handed to ``My_Dataset``.
        batch_size: Batch size used by all three dataloaders.
        split_lengths: Sizes of the train, validation and test subsets, in
            that order. They must add up to the dataset length.
        num_workers: Worker processes per dataloader.
    """

    def __init__(self, root_dir, batch_size, split_lengths=(1700, 423, 150), num_workers=7):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.split_lengths = list(split_lengths)
        self.num_workers = num_workers
        self.transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4131, 0.3576, 0.2830], std=[0.2309, 0.2194, 0.2037]),
        ])
        # Filled in by setup(); None means "not split yet".
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def setup(self, stage=None):
        """Build the dataset and split it once.

        setup() can be invoked more than once on the same instance (for
        example once per fit/test stage). Re-running the random split would
        hand out a different test subset that overlaps the samples trained
        on, so later calls reuse the first split.
        """
        if self.train_set is not None:
            return
        # Use the root passed to this instance rather than a module global.
        dataset = My_Dataset(root_dir=self.root_dir, transforms=self.transform)
        self.train_set, self.val_set, self.test_set = torch.utils.data.random_split(
            dataset, self.split_lengths
        )

    def _make_loader(self, subset, shuffle, drop_last):
        # Single place for the loader settings shared by all three splits.
        return torch.utils.data.DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled training loader; the incomplete last batch is dropped."""
        return self._make_loader(self.train_set, shuffle=True, drop_last=True)

    def test_dataloader(self):
        """Test loader over every test sample (no batch is dropped)."""
        return self._make_loader(self.test_set, shuffle=False, drop_last=False)

    def val_dataloader(self):
        """Validation loader over every validation sample (no batch is dropped)."""
        return self._make_loader(self.val_set, shuffle=False, drop_last=False)




# **Model**

In [None]:
class Model(pl.LightningModule):
    """Classifier made of a torchvision model followed by one linear layer.

    The torchvision model is used exactly as its constructor returns it; its
    flattened output is fed to ``self.classifier``.

    Args:
        model: Name of a constructor in ``torchvision.models`` (e.g. ``"resnet18"``).
        input_shape: Shape of one input sample without the batch dimension.
        num_classes: Number of output classes.
        learning_rate: Learning rate for the Adam optimizer.
        transfer: When true, load pretrained weights and freeze the
            torchvision model so only the linear layer is trained.
    """

    def __init__(self, model, input_shape, num_classes, learning_rate, transfer=True):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes
        self.pretrained_model = model
        # Remembered so train() can keep the frozen backbone in eval mode.
        self.freeze_backbone = transfer
        self.feature_extractor = getattr(models, self.pretrained_model)(pretrained=transfer)

        if transfer:
            # eval() fixes BatchNorm statistics and disables Dropout ...
            self.feature_extractor.eval()
            # ... and requires_grad=False stops the weights from being updated.
            for param in self.feature_extractor.parameters():
                param.requires_grad = False

        n_sizes = self._get_conv_output(input_shape)

        self.classifier = nn.Linear(n_sizes, num_classes)

        self.loss_function = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)

    def train(self, mode=True):
        """Switch train/eval mode while keeping a frozen backbone in eval mode.

        ``nn.Module.train()`` recursively flips every sub-module, so a
        whole-model ``.train()`` call would silently undo the ``eval()``
        applied in ``__init__``: BatchNorm running statistics would drift,
        Dropout would become active, and models such as inception_v3 and
        googlenet would start returning auxiliary outputs.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.feature_extractor.eval()
        return self

    # returns the size of the output tensor going into the Linear layer from the conv block.
    def _get_conv_output(self, shape):
        batch_size = 1
        tmp_input = torch.rand(batch_size, *shape)

        # Probe in eval mode and without autograd: the random batch must not
        # update BatchNorm statistics, and train-mode BatchNorm can reject a
        # single-sample batch. The previous mode is restored afterwards.
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        try:
            with torch.no_grad():
                output_feat = self._forward_features(tmp_input)
        finally:
            self.feature_extractor.train(was_training)

        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size

    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = self.feature_extractor(x)
        # Some torchvision models (inception_v3, googlenet) return a named
        # tuple of (main, auxiliary...) outputs in training mode; keep only
        # the main output so callers always receive a tensor.
        if isinstance(x, tuple):
            x = x[0]
        return x

    # will be used during inference
    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch."""
        loss, scores, labels = self._common_step(batch, batch_idx)
        accuracy = self.accuracy(scores, labels)
        self.log_dict({"train_loss": loss, "train_acc": accuracy}, prog_bar=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss and accuracy; logged per epoch only."""
        loss, scores, labels = self._common_step(batch, batch_idx)
        accuracy = self.accuracy(scores, labels)
        self.log_dict({"test_loss": loss, "test_acc": accuracy}, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and accuracy; logged per epoch only."""
        loss, scores, labels = self._common_step(batch, batch_idx)
        accuracy = self.accuracy(scores, labels)
        self.log_dict({"val_loss": loss, "val_acc": accuracy}, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def _common_step(self, batch, batch_idx):
        # Shared by the train/val/test steps: forward pass plus cross-entropy.
        inputs, labels = batch
        scores = self(inputs)
        loss = self.loss_function(scores, labels)
        return loss, scores, labels

    def configure_optimizers(self):
        """Return an Adam optimizer over the module's parameters."""
        return optim.Adam(params=self.parameters(), lr=self.learning_rate)


# **Finding best parameters**

In [None]:
# Search space sampled by the sweep: batch size, learning rate and which
# torchvision architecture to use.
sweep_parameters = {
    'batch_size': {'values': [4, 8, 16]},
    'learning_rate': {
        'distribution': 'log_uniform_values',
        'max': 1e-2,
        'min': 1e-5,
    },
    'models': {
        'values': [
            'resnet18',
            'vgg16',
            'squeezenet1_0',
            'inception_v3',
            'googlenet',
            'mobilenet_v2',
            'densenet161',
        ],
    },
}

# Random search that ranks runs by the logged 'test_acc' value.
# NOTE(review): choosing hyperparameters on the test metric leaks the test
# set into model selection; 'val_acc' may be the intended target -- confirm.
sweep_config = {
    'method': 'random',
    'metric': {'goal': 'maximize', 'name': 'test_acc'},
    'parameters': sweep_parameters,
}

sweep_id = wandb.sweep(sweep_config, project="classification-pytorch-lightning")


Create sweep with ID: dtmrenq3
Sweep URL: https://wandb.ai/mirokery/classification-pytorch-lightning/sweeps/dtmrenq3


# **Training models**

In [None]:
def train(config=None):
    """Train and test one model configuration inside a wandb run.

    Args:
        config: Optional default configuration passed to ``wandb.init``. When
            the function is launched by ``wandb.agent`` the sweep controller
            supplies the values instead.
    """
    # Fall back to CPU with full precision when no CUDA device is present,
    # rather than failing inside the Trainer.
    use_gpu = torch.cuda.is_available()

    # Initialize a new wandb run; leaving the `with` block finishes the run,
    # so no explicit wandb.finish() is needed.
    with wandb.init(config=config, project="classification-pytorch-lightning", name="training_with_sweeps_models"):
        # If called by wandb.agent, this config is set by the Sweep Controller.
        run_config = wandb.config

        # The Trainer invokes dm.setup() itself, so it is not called here.
        dm = DataModule(root_dir=root_dir_for_dataset, batch_size=run_config.batch_size)

        model = Model(
            model=run_config.models,
            # Derived from the same constant the data pipeline resizes to.
            input_shape=(3, *input_size),
            num_classes=number_of_classes,
            learning_rate=run_config.learning_rate,
        )
        wandb_logger = WandbLogger(
            project="classification-pytorch-lightning",
            name="training_with_sweeps_models",
            log_model=True,
        )
        trainer = pl.Trainer(
            logger=wandb_logger,
            precision=16 if use_gpu else 32,
            max_epochs=5,
            accelerator="gpu" if use_gpu else "cpu",
        )
        trainer.fit(model, dm)
        trainer.test(model, dm)




In [None]:
# Start a sweep agent for the sweep created above; it calls train() for each
# sampled configuration, capped at count=100 runs.
wandb.agent(sweep_id, train, count=100,project="classification-pytorch-lightning")

# **Download and use the best model**

In [None]:

# Open a wandb run and fetch one specific logged model artifact.
# The artifact reference is hard-coded; replace it to use a different run's model.
run = wandb.init()
artifact = run.use_artifact('mirokery/classification-pytorch-lightning/model-n6njivt0:v0', type='model')
# download() returns the local directory the artifact files were written to.
artifact_dir = artifact.download()


In [None]:
# Load the downloaded checkpoint and classify a single image.
print(artifact_dir)
ckpt_path = os.path.join(artifact_dir, 'model.ckpt')

# Run on the GPU when available and on the CPU otherwise. The model and the
# input are both placed on this device explicitly, so the cell no longer
# fails on CPU-only hosts or on a model/input device mismatch.
inference_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model.load_from_checkpoint(ckpt_path, map_location=inference_device)
model.to(inference_device)
model.eval()

# Label mapping used to turn the predicted index back into a class name.
# NOTE(review): this must match the class_to_idx the training dataset
# produced -- confirm against the training run.
class_to_idx ={'para': 0, 'spino': 1, 'stego': 2, 'trex': 3, 'velo': 4}
# Same preprocessing values as the training data module.
transform = transforms.Compose([
    transforms.Resize((400,400)),  # Resize to a fixed size
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.4131, 0.3576, 0.2830], std=[0.2309, 0.2194, 0.2037])  # Normalize image
])

# Close the file handle once the pixels have been converted to RGB.
with Image.open("/content/drive/MyDrive/Test Images/test/DSC04764.JPG") as raw_image:
    image = raw_image.convert('RGB')
input_tensor = transform(image)

# Add the batch dimension and move to the model's device as float32.
input_batch = input_tensor.unsqueeze(0)
input_batch = input_batch.to(device=inference_device, dtype=torch.float32)
print(input_batch.shape)
with torch.no_grad():
    output = model(input_batch)


print(output.shape)
predicted_idx = int(torch.argmax(output))
print({name for name, idx in class_to_idx.items() if idx == predicted_idx})