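# Model Training Notebook

Trains ResNet-18 models on COCO (`../data/coco`) for classification and detection and tracks each run in MLflow; the segmentation section is currently only a placeholder.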

In [1]:
import os
import sys

import mlflow
import mlflow.pytorch
import torch
import torchvision.models as models
from sklearn.metrics import classification_report

# Put the project root on sys.path so the local packages below (utils, models,
# evaluation, scripts) can be imported. os.path.dirname('../notebooks') is '..',
# so this appends the parent of the kernel's working directory — assumes the
# notebook is run from the notebooks/ folder; TODO confirm.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname('../notebooks'))))

from utils.utils import setup_mlflow, get_device
from utils.data_loader import load_data, load_data_detection
from models.classification.resnet_classification import Resnet18_Classification
from models.detection.resnet_detection import Resnet18_Detection
from evaluation.classification.eval_classification import plot_confusion_matrix, plot_misclassified_images, get_all_preds
from scripts.train_classification import train_classification
from scripts.train_detection import train_detection




In [2]:
# --- Utils -------------------------------------------------------------------
# Resolve the torch device a single time; every training cell below reuses `device`.
device = get_device()

Using GPU: NVIDIA GeForce RTX 4090


Classification Training
----------


In [3]:
# Classification training is switched off by default so "Restart & Run All" does not
# retrain it; set RUN_CLASSIFICATION = True to run it. Each hyperparameter is defined
# once and feeds BOTH the MLflow params and the actual calls, so the logged values
# always match what was trained.
RUN_CLASSIFICATION = False

CLS_DATA_DIR = '../data/coco'
CLS_EVAL_DIR = "../evaluation/classification"
CLS_SAVE_PATH = '../models_saved/classification.pth'
CLS_EPOCHS = 5              # TODO confirm: 5 was passed to train_classification, but 50 used to be logged
CLS_BATCH_SIZE = 128
CLS_IMAGE_SIZE = 224        # used for both resize_x and resize_y
# NOTE(review): learning_rate / momentum are only logged, never passed to
# train_classification — presumably it hard-codes them; confirm the values match.
CLS_LEARNING_RATE = 0.003
CLS_MOMENTUM = 0.9

if RUN_CLASSIFICATION:
    setup_mlflow("classification", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
    os.makedirs(CLS_EVAL_DIR, exist_ok=True)
    # torch.save does not create parent directories.
    os.makedirs(os.path.dirname(CLS_SAVE_PATH), exist_ok=True)

    with mlflow.start_run():
        mlflow.log_param("epochs", CLS_EPOCHS)
        mlflow.log_param("batch_size", CLS_BATCH_SIZE)
        mlflow.log_param("learning_rate", CLS_LEARNING_RATE)
        mlflow.log_param("momentum", CLS_MOMENTUM)

        net = Resnet18_Classification().to(device)
        trainloader, testloader, class_names = load_data(
            CLS_DATA_DIR, batch_size=CLS_BATCH_SIZE, shuffle=True,
            resize_x=CLS_IMAGE_SIZE, resize_y=CLS_IMAGE_SIZE,
        )
        trained_net = train_classification(net, trainloader, testloader, device, num_epochs=CLS_EPOCHS)

        torch.save(trained_net.state_dict(), CLS_SAVE_PATH)
        mlflow.pytorch.log_model(trained_net, "model")
        mlflow.log_artifact(CLS_SAVE_PATH)

        plot_confusion_matrix(trained_net, testloader, class_names, device,
                              save_path=f"{CLS_EVAL_DIR}/confusion_matrix.png")
        y_pred, y_true, misclassified_images, misclassified_labels, misclassified_preds = get_all_preds(trained_net, testloader, device)
        plot_misclassified_images(misclassified_images, misclassified_labels, misclassified_preds, class_names, num_images=5)

        # Computed and printed once (the previous version did this twice).
        class_report = classification_report(y_true, y_pred, target_names=class_names)
        print(class_report)

        report_path = f"{CLS_EVAL_DIR}/classification_report.txt"
        with open(report_path, "w") as f:
            f.write(class_report)
        mlflow.log_artifact(report_path)

    
    


    
    

Detection Training
----------


In [4]:
# Detection training. Each hyperparameter is defined once and feeds BOTH the MLflow
# params and the actual loader/trainer calls, so the run record matches what was
# trained (previously "epochs" was logged as 28 while num_epochs=50 was used).
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
DET_EPOCHS = 50            # upper bound; the recorded run printed "Early stopping at epoch 40"
DET_BATCH_SIZE = 50
# NOTE(review): learning_rate / momentum are only logged, never passed to
# train_detection — presumably it hard-codes its optimizer settings; confirm they match.
DET_LEARNING_RATE = 0.001
DET_MOMENTUM = 0.9

data_dir = '../data/coco'
save_path = '../models_saved/detection.pth'

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("detection")

with mlflow.start_run():
    mlflow.log_param("epochs", DET_EPOCHS)
    mlflow.log_param("batch_size", DET_BATCH_SIZE)
    mlflow.log_param("learning_rate", DET_LEARNING_RATE)
    mlflow.log_param("momentum", DET_MOMENTUM)

    net = Resnet18_Detection().to(device)
    trainloader, valloader = load_data_detection(data_dir, batch_size=DET_BATCH_SIZE, shuffle=True)
    trained_net = train_detection(net, trainloader, valloader, device, num_epochs=DET_EPOCHS)

    # torch.save does not create parent directories; make sure the target exists.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(trained_net.state_dict(), save_path)
    mlflow.pytorch.log_model(trained_net, "model")
    mlflow.log_artifact(save_path)



loading annotations into memory...
Done (t=9.16s)
creating index...
index created!
loading annotations into memory...
Done (t=0.42s)
creating index...
index created!


Epoch 1: 100%|██████████| 83/83 [00:38<00:00,  2.16it/s]


Epoch 1: Train Loss: 0.845, Val Loss: 0.527


Epoch 2: 100%|██████████| 83/83 [00:36<00:00,  2.26it/s]


Epoch 2: Train Loss: 0.358, Val Loss: 0.256


Epoch 3: 100%|██████████| 83/83 [01:25<00:00,  1.03s/it]


Epoch 3: Train Loss: 0.219, Val Loss: 0.204


Epoch 4: 100%|██████████| 83/83 [00:36<00:00,  2.25it/s]


Epoch 4: Train Loss: 0.191, Val Loss: 0.183


Epoch 5: 100%|██████████| 83/83 [00:35<00:00,  2.33it/s]


Epoch 5: Train Loss: 0.178, Val Loss: 0.175


Epoch 6: 100%|██████████| 83/83 [00:36<00:00,  2.30it/s]


Epoch 6: Train Loss: 0.168, Val Loss: 0.166


Epoch 7: 100%|██████████| 83/83 [00:36<00:00,  2.28it/s]


Epoch 7: Train Loss: 0.160, Val Loss: 0.163


Epoch 8: 100%|██████████| 83/83 [00:35<00:00,  2.33it/s]


Epoch 8: Train Loss: 0.154, Val Loss: 0.160


Epoch 9: 100%|██████████| 83/83 [00:36<00:00,  2.30it/s]


Epoch 9: Train Loss: 0.149, Val Loss: 0.152


Epoch 10: 100%|██████████| 83/83 [00:37<00:00,  2.24it/s]


Epoch 10: Train Loss: 0.144, Val Loss: 0.150


Epoch 11: 100%|██████████| 83/83 [00:36<00:00,  2.30it/s]


Epoch 11: Train Loss: 0.140, Val Loss: 0.149


Epoch 12: 100%|██████████| 83/83 [00:42<00:00,  1.93it/s]


Epoch 12: Train Loss: 0.137, Val Loss: 0.145


Epoch 13: 100%|██████████| 83/83 [00:47<00:00,  1.76it/s]


Epoch 13: Train Loss: 0.132, Val Loss: 0.140


Epoch 14: 100%|██████████| 83/83 [00:37<00:00,  2.24it/s]


Epoch 14: Train Loss: 0.131, Val Loss: 0.141


Epoch 15: 100%|██████████| 83/83 [00:42<00:00,  1.95it/s]


Epoch 15: Train Loss: 0.127, Val Loss: 0.134


Epoch 16: 100%|██████████| 83/83 [02:28<00:00,  1.79s/it]


Epoch 16: Train Loss: 0.125, Val Loss: 0.135


Epoch 17: 100%|██████████| 83/83 [00:51<00:00,  1.61it/s]


Epoch 17: Train Loss: 0.123, Val Loss: 0.131


Epoch 18: 100%|██████████| 83/83 [00:37<00:00,  2.19it/s]


Epoch 18: Train Loss: 0.121, Val Loss: 0.132


Epoch 19: 100%|██████████| 83/83 [00:39<00:00,  2.11it/s]


Epoch 19: Train Loss: 0.119, Val Loss: 0.129


Epoch 20: 100%|██████████| 83/83 [00:57<00:00,  1.44it/s]


Epoch 20: Train Loss: 0.117, Val Loss: 0.129


Epoch 21: 100%|██████████| 83/83 [01:28<00:00,  1.06s/it]


Epoch 21: Train Loss: 0.117, Val Loss: 0.130


Epoch 22: 100%|██████████| 83/83 [01:29<00:00,  1.08s/it]


Epoch 22: Train Loss: 0.116, Val Loss: 0.124


Epoch 23: 100%|██████████| 83/83 [00:37<00:00,  2.22it/s]


Epoch 23: Train Loss: 0.114, Val Loss: 0.125


Epoch 24: 100%|██████████| 83/83 [00:36<00:00,  2.26it/s]


Epoch 24: Train Loss: 0.113, Val Loss: 0.127


Epoch 25: 100%|██████████| 83/83 [00:39<00:00,  2.08it/s]


Epoch 25: Train Loss: 0.112, Val Loss: 0.125


Epoch 26: 100%|██████████| 83/83 [00:35<00:00,  2.32it/s]


Epoch 26: Train Loss: 0.111, Val Loss: 0.122


Epoch 27: 100%|██████████| 83/83 [00:37<00:00,  2.23it/s]


Epoch 27: Train Loss: 0.111, Val Loss: 0.122


Epoch 28: 100%|██████████| 83/83 [00:42<00:00,  1.95it/s]


Epoch 28: Train Loss: 0.111, Val Loss: 0.125


Epoch 29: 100%|██████████| 83/83 [00:48<00:00,  1.73it/s]


Epoch 29: Train Loss: 0.111, Val Loss: 0.122


Epoch 30: 100%|██████████| 83/83 [00:48<00:00,  1.70it/s]


Epoch 30: Train Loss: 0.111, Val Loss: 0.121


Epoch 31: 100%|██████████| 83/83 [00:48<00:00,  1.71it/s]


Epoch 31: Train Loss: 0.111, Val Loss: 0.122


Epoch 32: 100%|██████████| 83/83 [01:09<00:00,  1.19it/s]


Epoch 32: Train Loss: 0.110, Val Loss: 0.123


Epoch 33: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Epoch 33: Train Loss: 0.111, Val Loss: 0.123


Epoch 34: 100%|██████████| 83/83 [00:47<00:00,  1.73it/s]


Epoch 34: Train Loss: 0.110, Val Loss: 0.124


Epoch 35: 100%|██████████| 83/83 [00:35<00:00,  2.33it/s]


Epoch 35: Train Loss: 0.110, Val Loss: 0.120


Epoch 36: 100%|██████████| 83/83 [00:36<00:00,  2.28it/s]


Epoch 36: Train Loss: 0.110, Val Loss: 0.123


Epoch 37: 100%|██████████| 83/83 [00:35<00:00,  2.32it/s]


Epoch 37: Train Loss: 0.110, Val Loss: 0.122


Epoch 38: 100%|██████████| 83/83 [00:37<00:00,  2.23it/s]


Epoch 38: Train Loss: 0.110, Val Loss: 0.125


Epoch 39: 100%|██████████| 83/83 [00:37<00:00,  2.24it/s]


Epoch 39: Train Loss: 0.111, Val Loss: 0.123


Epoch 40: 100%|██████████| 83/83 [00:36<00:00,  2.27it/s]


Epoch 40: Train Loss: 0.110, Val Loss: 0.121
Early stopping at epoch 40


  net.load_state_dict(torch.load('best_model.pth'))


🏃 View run suave-wasp-969 at: http://localhost:5000/#/experiments/155698513309176258/runs/69e1d5d55bf941a1a421f067b57efa95
🧪 View experiment at: http://localhost:5000/#/experiments/155698513309176258


Segmentation Training
----------


In [5]:
# NOTE(review): placeholder only. `train_segmentation` is not imported in the imports
# cell at the top of this notebook, so un-commenting this as-is would raise NameError.
# The experiment name also differs in style from the "detection" experiment used above.
# mlflow.set_experiment("COCO_ResNet_Segmentation_Training")
# with mlflow.start_run(run_name="Segmentation_Model"):
#     train_segmentation()  