<a href="https://colab.research.google.com/github/gugi200/final_project/blob/main/comp_diff_models_custom_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# importing libraries

In [None]:
!pip install wandb -qU
!wandb login 3014974e724f01c4d63f956fa13fd7f0463e16d4
!pip install torchmetrics
!pip install mlxtend>=0.19.0

!pip list | grep mlx


#
#   Michael Gugala
#   02/12/2023
#   Image recognition
#   Master 4th year project
#   University of Bristol
#

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

import torchvision
from torchvision import datasets#
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchmetrics
from torchvision.models import resnet50, ResNet50_Weights
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix
from sklearn.utils import Bunch

from PIL import Image

import requests
import random
import shutil
import zipfile
from pathlib import Path
from io import BytesIO, StringIO
import os

import wandb
import cv2
from timeit import default_timer as timer
from tqdm.auto import tqdm

# check imports
# report the installed versions of the two core libraries
for _lib in (torch, torchvision):
    print(_lib.__version__)

# device-agnostic code: use the GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

# downloading data and splitting into raw train and test subsets

In [None]:
colabPath = Path("custom_dataset")

# Create the dataset directory (nothing to do when it already exists)
if colabPath.is_dir():
    print('directory already exists')
else:
    colabPath.mkdir(parents=True, exist_ok=True)

# Defined unconditionally: they used to be assigned only in the `else` branch
# above, so re-running the cell with the directory present raised NameError.
github_path = "https://github.com/gugi200/final_project/raw/main/data.zip"
# file = "dataset_pressure_sensor.zip"
file = 'data.zip'

# download zipped data; fetch first and fail loudly on an HTTP error so that
# an error page is never saved (and later unzipped) as data.zip
request = requests.get(github_path, timeout=60)
request.raise_for_status()
with open(colabPath / file, "wb") as f:
    f.write(request.content)


# unzip the data
with zipfile.ZipFile(colabPath / file, "r") as f:
    f.extractall(colabPath)




TRAIN_RATIO = 0.75
# Only sub-directories are classes: the downloaded archive (and any other
# stray file) sits in the same folder and must not be treated as one.
dirs = sorted(
    entry for entry in os.listdir(colabPath) if (colabPath / entry).is_dir()
)


#  Create a dir for train and test data
extendedTrain = Path("train_raw")
extendedTest = Path("test_raw")
if extendedTrain.is_dir():
    print('directory already exists')
extendedTrain.mkdir(parents=True, exist_ok=True)
extendedTest.mkdir(parents=True, exist_ok=True)

# NOTE(review): the original read from `extendedDataPath`, which is never
# defined; since `dirs` is listed from `colabPath`, that is assumed to be the
# intended source root - confirm against the archive layout.
for class_name in dirs:
    source_dir = colabPath / class_name
    train_dir = extendedTrain / class_name
    test_dir = extendedTest / class_name
    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)

    # A class that was already split is left alone: re-shuffling on a re-run
    # would copy some files into the other subset and leak test data into train.
    if any(train_dir.iterdir()) or any(test_dir.iterdir()):
        continue

    files = os.listdir(source_dir)
    length = int(TRAIN_RATIO*len(files))
    random.shuffle(files)

    train_set = files[:length]
    test_set = files[length:]

    for data in train_set:
        shutil.copy(source_dir / data, train_dir / data)

    for data in test_set:
        shutil.copy(source_dir / data, test_dir / data)

# For the train subset and then the test subset: print the number of files in
# every class folder, followed by the subset total.
for subset_root in (extendedTrain, extendedTest):
    l = 0
    for class_dir in dirs:
        n_files = len(os.listdir(subset_root/class_dir))
        l += n_files
        print(class_dir, n_files)
    print(l)

# Library

In [None]:


# create data set from a custom data
def create_dataset(path, batchsize, preprocess=None, mean=None, std=None,
                   shuffle=True):
    '''
    Build a DataLoader over an image folder (one sub-folder per class).

    input:
    path - path to the folder with the data
           eg for train - "data/FashionMNIST/train"
    batchsize - eg 32
    preprocess (optional) - transform applied to every image; when omitted the
           images are resized to 224x224 and converted to tensors
    mean (optional)- for normalization eg. [0.25, 0.25, 0.25]
           (only used when preprocess is omitted)
    std (optional)- for normalization eg [0.1, 0.1, 0.1]
           (required whenever mean is given)
    shuffle (optional) - whether the dataloader shuffles the samples,
           default True

    returns:
    dataloader
    class_names - list of class names (data.classes)
    targets - class index of every sample in dataset order (data.targets),
              i.e. NOT in the order a shuffled dataloader yields them
    '''
    if preprocess is None:
        steps = [
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
        ]
        if mean is not None:
            # Fail early with a clear message instead of deep inside the
            # transform on the first batch.
            if std is None:
                raise ValueError('std must be provided together with mean')
            steps.append(transforms.Normalize(mean=mean, std=std))
        preprocess = transforms.Compose(steps)

    data = datasets.ImageFolder(root=Path(path),
                                transform=preprocess,  # transform for the data
                                target_transform=None)  # transform for label
    dataloader = DataLoader(dataset=data,
                            batch_size=batchsize,
                            shuffle=shuffle)  # shuffling to remove order
    class_names = data.classes
    return dataloader, class_names, data.targets

# visualize 9 random images in a batch
def visualise_data(dataloader, class_names, batchsize):
    '''
    input:
    dataloader - dataloader yielding (images, labels) batches
    class_names - list mapping a label index to its class name
    batchsize - nominal batch size of the dataloader

    Takes the first batch of the dataloader and displays 9 randomly chosen
    images from it (with replacement) together with their labels.
    '''
    train_features_batch, train_labels_batch = next(iter(dataloader))
    print("length of data: ", len(train_features_batch), 'length of labels: ', len(train_labels_batch))
    # The batch can hold fewer than `batchsize` samples (small dataset), so
    # the random index must be bounded by what was actually returned.
    n_available = min(batchsize, len(train_features_batch))
    # display random datapoints
    fig = plt.figure(figsize=(9, 9))
    rows, cols = 3, 3
    for pic in range(1, 1+rows*cols):
        rand_int = np.random.randint(0, n_available)
        img = train_features_batch[rand_int]
        # channels-first tensor -> channels-last array for matplotlib
        img_RGB = img.permute([1, 2, 0]).numpy()
        fig.add_subplot(rows, cols, pic)
        plt.imshow(img_RGB.squeeze())
        plt.axis(False)
        plt.title(class_names[train_labels_batch[rand_int]])


# create optimizer
def create_optiimizer(model, optimizer, lr):
    '''
    inputs:
    model - CNN network
    optimizer - "adam" or "sgd" (SGD uses momentum 0.9); an already built
                torch.optim.Optimizer is returned unchanged
    lr - learning rate eg. 0.01

    returns:
    the torch optimizer bound to model.parameters()

    raises:
    ValueError - for an unknown optimizer name (previously the name itself was
                 returned and the failure only surfaced inside the train loop)
    '''
    if isinstance(optimizer, torch.optim.Optimizer):
        return optimizer
    if optimizer == 'adam':
        return torch.optim.Adam(model.parameters(), lr=lr)
    if optimizer == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    raise ValueError(
        f"unknown optimizer {optimizer!r}; expected 'adam' or 'sgd'"
    )


def get_lossFn():
    '''
        builds a new CrossEntropyLoss criterion and returns it
    '''
    criterion = nn.CrossEntropyLoss()
    return criterion

# create train step
def train_step(model, metric, loss_fn, optimizer,
               data_loader, device, debug=False, wnb=True):
    '''
    Runs one training epoch over data_loader.

    model - CNN network
    metric - metric to calculate accuracy, called as metric(preds, target)
             (the torchmetrics argument order)
    loss_fn - loss function
    optimizer - optimizer to be applied
    data_loader - dataloader
    device - device for the model to train
    debug (optional)- if True prints the average loss and metric of the epoch
    wnb (optional)- if True logs the loss and metric of every batch to WandB

    returns
    train_loss - loss averaged over the batches
    train_acc - metric score averaged over the batches

    '''
    train_loss, train_acc = 0, 0

    model.train()
    for batch, (X, y) in enumerate(data_loader):
        # put data on the device
        X, y = X.to(device), y.to(device)

        # forward pass, return raw logits
        y_pred = model(X)

        # loss
        loss = loss_fn(y_pred, y)
        # accuracy: predictions first, targets second
        acc = metric(torch.argmax(y_pred, dim=1), y)

        # Accumulate a detached copy: summing `loss` itself would keep every
        # batch's autograd graph referenced until the end of the epoch.
        train_loss += loss.detach()
        train_acc += acc  # accumulate train accuracy

        # zero grad
        optimizer.zero_grad()

        # loss backward
        loss.backward()

        # optimizer step
        optimizer.step()

        # log in wandb
        if wnb:
            wandb.log({"loss": loss.item(),
                       'accuracy': acc})

    # divide total loss and accuracy by length of train dataloader
    train_loss /= len(data_loader)
    train_acc /= len(data_loader)
    if debug:
        print(f'Train loss: {train_loss:.4f}, Train acc: {train_acc*100:0.4f}%')

    return train_loss, train_acc


# create test step
def test_step(model, metric, loss_fn, data_loader, device, debug=False, wnb=True):
    '''
    model - CNN network
    metric - metric to calculate accuracy
    loss_fn - loss function
    data_loader - dataloader
    device - decide for the model to train
    debug (optional)- if True prints average loss and metric of the batch

    returns
    test_loss - average loss of the batch
    test_acc - average metric score of the batch

    The function saves the metric score and loss of each iteration in WandB

    '''
    test_loss, test_acc = 0, 0
    model.eval()
    with torch.inference_mode():
        for X_test, y_test in data_loader:
            X_test, y_test = X_test.to(device), y_test.to(device)
            #1 forward pass
            test_pred = model(X_test)

            # calculate loss
            loss = loss_fn(test_pred, y_test)
            test_loss += loss

            #accuracy
            acc = metric(y_test, test_pred.argmax(dim=1))
            test_acc += acc

            if wnb:
                wandb.log({"test loss": loss,
                           'test accuracy': acc})

        # Calculate the test loss average batch
        test_loss /= len(data_loader)

        # acc per bactch
        test_acc /= len(data_loader)

        # Print out what's happening
        if debug:
            print(f'Test loss: {test_loss:.4f}  |  Test acc: {test_acc*100:.4f}%')

        return test_loss, test_acc

# create evaluation loop
def eval_model(model: torch.nn.Module,
                data_loader: torch.utils.data.DataLoader,
                loss_fn: torch.nn.Module,
                accuracy_fn,
               device):
    '''
    Evaluates model over data_loader without gradient tracking.

    accuracy_fn is called as accuracy_fn(labels, predicted_classes).

    returns a dict with the model's class name, the loss averaged over the
    batches and the accuracy averaged over the batches times 100
    '''
    total_loss = 0
    total_acc = 0
    model.eval()
    with torch.inference_mode():
        for inputs, labels in tqdm(data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)

            # running sums over the batches
            total_loss = total_loss + loss_fn(logits, labels)
            total_acc = total_acc + accuracy_fn(labels, logits.argmax(dim=1))

        # per-batch averages
        n_batches = len(data_loader)
        mean_loss = total_loss / n_batches
        mean_acc = total_acc / n_batches

    summary = {}
    # the class name is only meaningful for models created with a class
    summary["model_name"] = model.__class__.__name__
    summary["model_loss"] = mean_loss.item()
    summary["model_acc"] = mean_acc.item()*100
    return summary




def visualize_preds(model, dataloader, class_names, batchsize):
    '''
    Plots a 3x3 grid of predictions: three random samples from each of the
    first three batches of the dataloader.

    model - network used for the predictions (moved to the CPU by this call)
    dataloader - dataloader yielding (images, labels) batches; must provide
                 at least three batches
    class_names - list mapping a label index to its class name
    batchsize - nominal batch size of the dataloader

    Each title shows the predicted and the true class, green when they agree
    and red otherwise.
    '''
    plt.figure(figsize=(9, 9))
    nrows = 3
    ncols = 3
    model = model.cpu()
    model.eval()

    data = iter(dataloader)
    with torch.inference_mode():
        for i in range(nrows):
            X, y = next(data)
            X, y = X.cpu(), y.cpu()
            # A batch can be shorter than `batchsize` (e.g. the last one), so
            # bound the random index by the samples actually present.
            n_available = min(batchsize, len(X))
            for j in range(ncols):
                randint = np.random.randint(0, n_available)
                X_sample, y_sample = X[randint], y[randint]
                pred_logit = model(X_sample.unsqueeze(dim=0))

                # index of the highest logit = predicted class
                pred_class = pred_logit.argmax(dim=1).item()

                plt.subplot(nrows, ncols, (ncols*i)+j+1)
                plt.imshow(X_sample.squeeze().permute([1, 2, 0]), cmap='gray')

                # find pred_label in text form
                pred_label = class_names[pred_class]

                # find truth label
                truth_label = class_names[y_sample.item()]

                title_text = f'Pred: {pred_label}  \n  Truth: {truth_label}'

                if pred_label == truth_label:
                    plt.title(title_text, fontsize=10, c='g')
                else:
                    plt.title(title_text, fontsize=10, c='r')
                plt.axis(False)
                plt.tight_layout()



def plot_decision_matrix(class_names, y_pred_tensor, targets):
    '''
    Computes the multiclass confusion matrix of y_pred_tensor against targets
    and draws it with one row/column per entry of class_names.
    '''
    # one class per name
    n_classes = len(class_names)
    confusion = ConfusionMatrix(task='multiclass', num_classes=n_classes)
    counts = confusion(preds=y_pred_tensor, target=targets)

    # draw the matrix
    _fig, _ax = plot_confusion_matrix(
        conf_mat=counts.numpy(),
        class_names=class_names,
        figsize=(10, 7),
    )



def make_predictions_dataloader(model, dataloader, device, class_names):
    '''
    Runs the model over every batch of the dataloader.

    model - network used for the predictions (moved to device)
    dataloader - dataloader yielding (images, labels) batches
    device - device to run the model on
    class_names - list of class names; its length sets the metric's class count

    returns
    predictions - numpy array with the predicted class index of every sample
    targets - numpy array with the true class index of every sample
    test_acc - multiclass accuracy averaged over the batches (CPU tensor)
    '''
    preds = []
    target = []
    model = model.to(device)
    model.eval()
    test_acc = 0
    metric = torchmetrics.classification.Accuracy(
        task="multiclass",
        num_classes=len(class_names)
    ).to(device)
    with torch.inference_mode():

        for X_test, y_test in tqdm(dataloader):
            X_test, y_test = X_test.to(device), y_test.to(device)
            batch_labels = model(X_test).argmax(dim=1)

            # Score while the tensors are still on the metric's device, in the
            # torchmetrics (preds, target) order; the original moved them to
            # the CPU first although the metric lives on `device`.
            test_acc += metric(batch_labels, y_test)

            preds.append(batch_labels.cpu().numpy())
            target.append(y_test.cpu().numpy())

        # acc per batch
        test_acc /= len(dataloader)

    return np.concatenate(preds), np.concatenate(target), test_acc.cpu()


def make_predictions(model, data, device):
    '''
    Puts the model in eval mode, runs it on data on the given device without
    gradient tracking and returns the raw model output moved to the CPU.
    '''
    model.eval()
    inputs = data.to(device)
    net = model.to(device)
    with torch.inference_mode():
        outputs = net(inputs)
    on_cpu = outputs.cpu()
    return on_cpu


def dataloader_to_numpy(dataloader):
    '''
    Collects every batch of the dataloader into two numpy arrays.

    dataloader - iterable of (data, target) tensor batches

    returns
    data_numpy - all data batches concatenated along axis 0
    target_numpy - all target batches concatenated along axis 0
    (two empty arrays when the dataloader yields nothing)
    '''
    data_chunks = []
    target_chunks = []
    for data, target in dataloader:
        data_chunks.append(data.numpy())
        target_chunks.append(target.numpy())

    if not data_chunks:
        return np.empty(0), np.empty(0)

    # One concatenation at the end: np.append inside the loop re-copies
    # everything gathered so far on every batch.
    return (np.concatenate(data_chunks, axis=0),
            np.concatenate(target_chunks, axis=0))


def get_datalodaer(batchsize, train_path, test_path):
    '''
    Builds the train and test dataloaders with create_dataset, normalising
    both with the same per-channel mean/std (the values match the commonly
    used ImageNet statistics).

    returns train_dataloader, test_dataloader and the class names taken from
    the train folder
    '''
    def _load(folder):
        # same batch size and normalisation for both subsets
        return create_dataset(
            path=folder,
            batchsize=batchsize,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    train_dataloader, class_names, _ = _load(train_path)
    test_dataloader, _, _ = _load(test_path)
    return train_dataloader, test_dataloader, class_names


def train_test_loop(config, model, train_dataloader, test_dataloader,
                    class_names, device):
    '''
    Trains and evaluates the model for config.epochs epochs and logs the
    per-epoch averages and the total run time to WandB.

    config - object exposing optimizer, learning_rate and epochs
             (e.g. wandb.config)
    model - network to train
    train_dataloader - dataloader used by train_step
    test_dataloader - dataloader used by test_step
    class_names - list of class names; its length sets the metric's class count
    device - device the metric is placed on and the steps run on
    '''
    loss_fn = get_lossFn()
    optimizer = create_optiimizer(model=model,
                                    optimizer=config.optimizer,
                                    lr=config.learning_rate
    )
    metric = torchmetrics.classification.Accuracy(
        task="multiclass",
        num_classes=len(class_names)
    ).to(device)
    train_time_start = timer()
    for epoch in range(config.epochs):
        # Separate names for the two result pairs: the original reused one
        # pair, so the test results overwrote the train results and the
        # "train" keys below logged test values.
        train_loss, train_metric = train_step(
            model=model,
            metric=metric,
            loss_fn=loss_fn,
            optimizer=optimizer,
            data_loader=train_dataloader,
            device=device,
            debug=True
        )
        test_loss, test_metric = test_step(
            model=model,
            metric=metric,
            loss_fn=loss_fn,
            data_loader=test_dataloader,
            device=device,
            debug=True
        )
        wandb.log({"average train batch loss": train_loss,
                    "average train batch metric": train_metric,
                    "average test batch loss": test_loss,
                    "average test batch metric": test_metric,
                    "epoch": epoch
                    })
    train_time_end = timer()
    wandb.log({"train time": train_time_end - train_time_start})


def train_test_loop_nonpipe(model, train_dataloader, test_dataloader, optim, lr,
                    epochs, class_names, device):
    '''
    Same training/evaluation loop as train_test_loop, but configured through
    explicit arguments instead of a config object.

    model - network to train
    train_dataloader - dataloader used by train_step
    test_dataloader - dataloader used by test_step
    optim - optimizer name passed to create_optiimizer ("adam" or "sgd")
    lr - learning rate
    epochs - number of epochs
    class_names - list of class names; its length sets the metric's class count
    device - device the metric is placed on and the steps run on

    Logs the per-epoch averages and the total run time to WandB.
    '''
    loss_fn = get_lossFn()
    optimizer = create_optiimizer(model=model,
                                    optimizer=optim,
                                    lr=lr
    )
    metric = torchmetrics.classification.Accuracy(
        task="multiclass",
        num_classes=len(class_names)
    ).to(device)
    train_time_start = timer()
    for epoch in range(epochs):
        # Separate names for the two result pairs: the original reused one
        # pair, so the test results overwrote the train results and the
        # "train" keys below logged test values.
        train_loss, train_metric = train_step(
            model=model,
            metric=metric,
            loss_fn=loss_fn,
            optimizer=optimizer,
            data_loader=train_dataloader,
            device=device,
            debug=True
        )
        test_loss, test_metric = test_step(
            model=model,
            metric=metric,
            loss_fn=loss_fn,
            data_loader=test_dataloader,
            device=device,
            debug=True
        )
        wandb.log({"average train batch loss": train_loss,
                    "average train batch metric": train_metric,
                    "average test batch loss": test_loss,
                    "average test batch metric": test_metric,
                    "epoch": epoch
                    })
    train_time_end = timer()
    wandb.log({"train time": train_time_end - train_time_start})


# models

## setup wandb

In [None]:





# fixed seeds for the CPU and CUDA random number generators
torch.manual_seed(42)
torch.cuda.manual_seed(42)


# objective of the sweep: minimise the logged "loss"
metric = {
    'name': 'loss',
    'goal': 'minimize',
}

# Search space: every hyperparameter is an explicit list of values (or one
# fixed value), as a grid sweep tries every combination. A random sweep could
# instead sample learning_rate / batch_size from a distribution.
parameters_dict = {
    'optimizer': {'values': ['adam', 'sgd']},
    'fc_layer_size': {'values': [6]},
    'epochs': {'value': 5},
    'learning_rate': {'values': [0.001, 0.01, 0.1]},
    'batch_size': {'values': [16, 32]},
}

sweep_config = {
    'method': 'grid',
    'metric': metric,
    'parameters': parameters_dict,
}


import pprint
pprint.pprint(sweep_config)




In [None]:

def train_model_resnet50(config=None):
    '''
    Fine-tunes a pretrained ResNet-50 inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        model = resnet50(
            weights=ResNet50_Weights.DEFAULT).to(device)
        # replace the classification head with one of the configured size
        model.fc = nn.Linear(2048, config.fc_layer_size).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_alexnet(config=None):
    '''
    Fine-tunes a pretrained AlexNet inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('alexnet'))[-1]
        model = torch.hub.load('pytorch/vision', 'alexnet', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier[6] = nn.Linear(4096, config.fc_layer_size).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_convnext_base(config=None):
    '''
    Fine-tunes a pretrained ConvNeXt-Base inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('convnext_base'))[-1]
        model = torch.hub.load('pytorch/vision', 'convnext_base', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier[2] = nn.Linear(1024, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_densenet161(config=None):
    '''
    Fine-tunes a pretrained DenseNet-161 inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('densenet161'))[-1]
        model = torch.hub.load('pytorch/vision', 'densenet161', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier = nn.Linear(2208, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_efficientnet_v2_l(config=None):
    '''
    Fine-tunes a pretrained EfficientNetV2-L inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('efficientnet_v2_l'))[-1]
        model = torch.hub.load('pytorch/vision', 'efficientnet_v2_l', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier[1] = nn.Linear(1280, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_googlenet(config=None):
    '''
    Fine-tunes a pretrained GoogLeNet inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('googlenet'))[-1]
        model = torch.hub.load('pytorch/vision', 'googlenet', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.fc = nn.Linear(1024, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_inception_v3(config=None):
    '''
    Fine-tunes a pretrained Inception v3 inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('inception_v3'))[-1]
        model = torch.hub.load('pytorch/vision', 'inception_v3', weights=weight)
        # With its auxiliary head enabled, torchvision's Inception3 returns a
        # (logits, aux_logits) tuple in training mode, which train_step's
        # loss_fn(y_pred, y) cannot consume. Drop the auxiliary head after the
        # weights are loaded so the model always returns plain logits.
        model.aux_logits = False
        model.AuxLogits = None
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.fc = nn.Linear(2048, config.fc_layer_size, bias=True)
        model = model.to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_maxvit_t(config=None):
    '''
    Fine-tunes a pretrained MaxViT-T inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('maxvit_t'))[-1]
        model = torch.hub.load('pytorch/vision', 'maxvit_t', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier[5] = nn.Linear(512, config.fc_layer_size, bias=False).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_mobilenet_v3_large(config=None):
    '''
    Fine-tunes a pretrained MobileNetV3-Large inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('mobilenet_v3_large'))[-1]
        model = torch.hub.load('pytorch/vision', 'mobilenet_v3_large', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier[3] = nn.Linear(1280, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model


def train_model_shufflenet_v2_x2_0(config=None):
    '''
    Fine-tunes a pretrained ShuffleNetV2 x2.0 inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('shufflenet_v2_x2_0'))[-1]
        model = torch.hub.load('pytorch/vision', 'shufflenet_v2_x2_0', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.fc = nn.Linear(2048, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_swin_v2_t(config=None):
    '''
    Fine-tunes a pretrained Swin Transformer V2-T inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('swin_v2_t'))[-1]
        model = torch.hub.load('pytorch/vision', 'swin_v2_t', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.head = nn.Linear(768, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_vgg19_bn(config=None):
    '''
    Fine-tunes a pretrained VGG-19 (batch-norm variant) inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('vgg19_bn'))[-1]
        model = torch.hub.load('pytorch/vision', 'vgg19_bn', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.classifier[6] = nn.Linear(4096, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model

def train_model_wide_resnet50_2(config=None):
    '''
    Fine-tunes a pretrained Wide ResNet-50-2 inside a WandB run.

    config (optional) - configuration forwarded to wandb.init; batch_size and
        fc_layer_size are read here, optimizer, learning_rate and epochs by
        train_test_loop

    returns the trained model
    '''
    with wandb.init(config=config):
        config = wandb.config

        # get_datalodaer's keywords are train_path / test_path
        train_dataloader, test_dataloader, class_names = get_datalodaer(
            train_path=shortTrain,
            test_path=shortTest,
            batchsize=config.batch_size)

        # last entry of the model's torchvision weights enum
        weight = list(torchvision.models.get_model_weights('wide_resnet50_2'))[-1]
        model = torch.hub.load('pytorch/vision', 'wide_resnet50_2', weights=weight).to(device)
        # Head sized from the sweep config (as in train_model_resnet50)
        # instead of a hard-coded 10; it must equal the number of classes.
        model.fc = nn.Linear(2048, config.fc_layer_size, bias=True).to(device)

        # train_test_loop requires the device as well
        train_test_loop(config=config,
                        model=model,
                        train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        class_names=class_names,
                        device=device
                        )
    return model


count=10