In [None]:
import os
import random

import torch
import torchvision

# Seed the Python and torch RNGs so data shuffling and the random init of the
# replacement classification head repeat across "Restart & Run All".
# NOTE(review): seed numpy too if the project's data pipeline uses it.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# `image_classification` is imported relative to the parent of the notebook's
# directory. Only change directory when the package is not already visible, so
# re-running this cell (or starting from the project root) does not keep
# walking up the directory tree.
if not os.path.isdir('image_classification'):
    os.chdir('..')

from image_classification.classifier import PolysecureClassifier

### Create Classifier 

In [None]:
# Instantiate the project's classifier wrapper. Later cells use its
# `dataloaders`, `class_names` and `store_dir` attributes, plus its
# training/evaluation helpers (`train_model`, `print_confusion_matrix`, ...).
classifier = PolysecureClassifier()

### Display Mini-Batch Example

In [None]:
# Pull a single batch from the training loader and show it as one image grid,
# titled with the human-readable class name of each sample.
train_loader = classifier.dataloaders['train']
sample = next(iter(train_loader))
inputs = sample['image']
classes = sample['label']
out = torchvision.utils.make_grid(inputs)
batch_titles = [classifier.class_names[label_idx] for label_idx in classes]
classifier.imshow(out, batch_titles)

### Load Pretrained ResNet-18 and Replace Its Classification Head

In [None]:
# Start from an ImageNet-pretrained ResNet-18 and replace its final fully
# connected layer with one that has one output per class in this dataset.
_resnet18_weights = getattr(torchvision.models, 'ResNet18_Weights', None)
if _resnet18_weights is not None:
    # torchvision >= 0.13 deprecates `pretrained=True` in favour of explicit
    # weight enums. IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
    model = torchvision.models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    # Older torchvision without the weights enum.
    model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(classifier.class_names))

### Initialize Loss, Optimizer, and Learning-Rate Scheduler

In [None]:
# Hyper-parameters for fine-tuning.
LEARNING_RATE = 0.001
MOMENTUM = 0.9
LR_STEP_EPOCHS = 10  # decay the LR every this many scheduler steps
LR_GAMMA = 0.1       # multiplicative LR decay factor
# NOTE(review): if train_model steps the scheduler once per epoch (not visible
# here), a 10-step period never triggers with the 2-epoch run below.

# Class-weighted cross-entropy (weights come from the project helper,
# presumably to counter class imbalance — confirm).
class_weights = classifier.get_classes_weight()
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_STEP_EPOCHS,
    gamma=LR_GAMMA,
)

### Run Training

In [None]:
# Fine-tune the network. The run name 'Test' matches the name the
# class-activation-map cell uses to build "<name>_weights.pth"; presumably
# train_model saves the weights under that name — confirm.
training_kwargs = {
    'criterion': criterion,
    'optimizer': optimizer,
    'scheduler': exp_lr_scheduler,
    'num_epochs': 2,
    'model_name': 'Test',
    'early_stopping': True,
}
model_ft = classifier.train_model(model=model, **training_kwargs)

### Display Confusion Matrix

In [None]:
# Show the confusion matrix of the fine-tuned model. The helper lives in
# PolysecureClassifier; which data split it evaluates on is not visible here.
classifier.print_confusion_matrix(model_ft)

### Plot All Misclassified Images

In [None]:
# Show the images the fine-tuned model misclassifies. `plot=True` presumably
# makes the helper render them instead of only listing them — confirm in
# PolysecureClassifier.
classifier.print_misclassified(model_ft, plot=True)

### Display Class Activation Maps (Using the Saved Weights)

In [None]:
from image_classification.cam import CAM

# Rebuild the ResNet-18 architecture with the same replacement head, then load
# the weights saved for the run named below. The name must match the
# `model_name` passed to train_model.
model_name = "Test"
weights_path = os.path.join(classifier.store_dir, 'model', model_name + "_weights.pth")

# Use a separate name so the in-memory trained `model_ft` is not silently
# replaced by a reloaded copy (hidden-state hazard if earlier cells re-run).
cam_model = torchvision.models.resnet18()
num_ftrs = cam_model.fc.in_features
cam_model.fc = torch.nn.Linear(num_ftrs, len(classifier.class_names))
# map_location='cpu': the model is built on CPU, and this lets weights saved
# from a GPU run load on a CPU-only machine. load_state_dict then copies the
# tensors into the model's own parameters.
cam_model.load_state_dict(torch.load(weights_path, map_location='cpu'))
# Inference only: put dropout/batch-norm layers in evaluation mode.
cam_model.eval()

cam = CAM(classifier)
cam.print_cam(cam_model)