## Training Classifier on Heavy Dirt and Clean Images

In [1]:
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from sys import platform

# `ResNet` is ResNet-18 (18 weight layers, built from residual skip blocks); `ResNet50` is the deeper 50-layer variant
from models.resnet import ResNet, ResNet50

# These libraries hold dataset information
from dataset import DataloopDataset,DataloopDatasetClean,NonDataloopDatasetSubset

# This is the Trainer Library
from training_wheels import TrainingWheels

# Library for altering the "raw" images: supports random rotation, flipping, and adding "noise"
from augmentation import Augmentation

from datetime import datetime
from PIL import Image
import matplotlib.pyplot as plt

from argparse import ArgumentParser
import os

In [11]:
# Central run configuration. Every later cell reads its settings from this
# mapping, so this is the one place to tune an experiment.
hparams = dict(
    dataset_dir="dataloop",            # handed to DataloopDataset
    model="resnet18",                  # one of: "efficientnet", "resnet18", "resnet50"
    use_wandb=True,                    # when False no WandbLogger is created
    enable_random_cropping=False,      # Augmentation option
    batch_size=2,
    num_gpus=1,                        # 0 selects the non-GPU fallback
    downscaling_width=1224,            # Augmentation option
    downscaling_height=1632,           # Augmentation option
    max_epochs=100,
    accelerator=None,                  # not read by the cells visible here
    devices=None,                      # not read by the cells visible here
    use_dali=False,                    # switches to the TrainingWheelsDALI path
    center_crop=448,                   # not read by the cells visible here
    enable_vertical_mirroring=False,   # Augmentation option
    enable_horizontal_mirroring=True,  # Augmentation option
    random_rotation_angle=15,          # Augmentation option
    noise_amount=0,                    # Augmentation option
    resume_from_checkpoint=None,       # forwarded to trainer.fit(ckpt_path=...)
    enable_image_logging=True,
    lr=0.0001,
    balance_sampler=False,
    train_sample_size=250,             # forwarded as balance_sample_size
)

In [12]:
# Capture the current local date and time once; the formatted string is
# used further down as the display name of the W&B run.
now = datetime.now()
dt_string = format(now, "%m/%d/%Y %H:%M")

In [13]:
dataset_dir = hparams['dataset_dir']

# Build the augmentation object. Each option is taken from `hparams` under
# the same key name as the Augmentation keyword it is passed to.
augmentation_option_names = (
    'enable_random_cropping',
    'enable_vertical_mirroring',
    'enable_horizontal_mirroring',
    'random_rotation_angle',
    'noise_amount',
    'downscaling_width',
    'downscaling_height',
)
augmentation = Augmentation(**{name: hparams[name] for name in augmentation_option_names})

In [14]:
# Training split of the Dataloop dataset, read from `dataset_dir` and
# given the augmentation object configured in the previous cell.
dataset_options = {
    'dataset_dir': dataset_dir,
    'train': True,
    'augmentation': augmentation,
}
dataset = DataloopDataset(**dataset_options)

In [15]:
batch_size = hparams['batch_size']

# Log to Weights & Biases only when requested; with `None` the Trainer
# falls back to its own default logger behaviour.
logger = (
    WandbLogger(project="car-condition-classifier", log_model='all', name=dt_string)
    if hparams['use_wandb']
    else None
)

# Instantiate the network selected by hparams['model'].
if hparams['model'] == "efficientnet":
    # NOTE(review): EfficientNetV2 is not imported in the cells visible here,
    # so selecting "efficientnet" raises NameError until an import is added.
    model = EfficientNetV2()
elif hparams['model'] == "resnet18":
    model = ResNet()
elif hparams['model'] == "resnet50":
    model = ResNet50()
else:
    # Raise instead of `assert False`: assertions are stripped under
    # `python -O`, which would leave `model` undefined further down.
    raise ValueError("Unknown model: {}".format(hparams['model']))

# Use an explicit "cpu" fallback. `accelerator=None` is rejected as an invalid
# accelerator name by PyTorch Lightning 2.x, so a CPU-only run (num_gpus == 0)
# would fail when the Trainer is constructed.
accelerator = "gpu" if hparams['num_gpus'] > 0 else "cpu"



In [16]:
# Wrap model + dataset in the training module: the standard TrainingWheels
# by default, or the DALI-based variant when hparams['use_dali'] is set.
if not hparams['use_dali']:
    training_wheels = TrainingWheels(
        model=model,
        dataset=dataset,
        batch_size=batch_size,
        augmentation=augmentation,
        lr=hparams['lr'],
        enable_image_logging=hparams['enable_image_logging'],
        validation_set_size=0.2,
        balance_sampler=hparams['balance_sampler'],
        balance_sample_size=hparams['train_sample_size'],
    )
else:
    # NOTE(review): TrainingWheelsDALI is not imported in the cells visible
    # here, so this branch raises NameError until an import is added.
    training_wheels = TrainingWheelsDALI(
        model=model,
        dataset=dataset,
        batch_size=batch_size,
        lr=hparams['lr'],
    )
    # Drop PL_TRAINER_GPUS from the environment (if present) before the
    # Trainer is created on the DALI path.
    os.environ.pop('PL_TRAINER_GPUS', None)

# The Trainer is configured identically for both paths; at least one device
# is always requested, even when num_gpus is 0.
num_devices = max(hparams['num_gpus'], 1)
trainer = pl.Trainer(
    max_epochs=int(hparams['max_epochs']),
    logger=logger,
    accelerator=accelerator,
    devices=num_devices,
)

/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
trainer.fit(training_wheels, ckpt_path=hparams['resume_from_checkpoint'])

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [18]:
# Finish the Weights & Biases run now that training is done; the run history
# and summary shown below are printed by this call.
wandb.finish()



0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▆█▆▃▃▃▂▆▄▄▆▄▂▁▂▁▂▅▂▂▁▁▅▁▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_acc,▅▇▁▇█▆▇▄▇▇█▇▇█▅▇█▇██▇▆▅▇▇█████▇▇██▇▇▆█▇█
val_loss,▄▂█▂▂▄▂▅▂▂▁▂▂▁▅▂▁▂▁▁▂▃▅▃▃▁▁▁▁▁▂▂▁▁▂▂▃▁▂▁

0,1
epoch,99.0
train_loss,0.00111
trainer/global_step,12499.0
val_acc,0.94692
val_loss,0.16657
