In [1]:
import torch
import torch.nn as nn 
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models

In [2]:
class CNN_model(nn.Module):
    """Small convolutional image classifier: three conv stages plus an MLP head.

    Each convolution stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)``, so channels go ``in_channel -> 16 -> 32 -> 64`` while every
    stage halves the spatial size. The first ``Linear`` layer of the head is
    sized for ``16 * 16 * 64 = 16384`` flattened features, which means the
    network expects 128x128 inputs (128 / 2**3 = 16). Other spatial sizes will
    fail at the first ``Linear`` layer with a shape-mismatch error.

    Submodule names (``feature_extractor``, ``fully_connected``) and the layer
    positions inside each ``nn.Sequential`` determine the ``state_dict`` keys,
    so they must stay as they are for existing checkpoints to keep loading.

    Args:
        in_channel: Number of channels of the input images (e.g. 3 for RGB).
        num_classes: Number of output logits. Defaults to 9 (the waste
            categories this model was originally written for).
    """

    def __init__(self, in_channel, num_classes=9):
        super().__init__()

        # Convolution blocks: spatial size 128 -> 64 -> 32 -> 16.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channel, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Fully connected classification head.
        self.fully_connected = nn.Sequential(
            nn.Linear(16 * 16 * 64, 512),     # 16384 -> 512
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 64),               # 512 -> 64
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(64, num_classes),       # Output layer (one logit per class)
        )

    def forward(self, x):
        """Return raw class logits for a batch of images.

        Args:
            x: Image batch; the layer sizes above imply shape
                ``(batch, in_channel, 128, 128)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits
            (no softmax is applied).
        """
        features = self.feature_extractor(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        features = torch.flatten(features, 1)
        return self.fully_connected(features)

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint file read by the loading cells below; edit this path to point
# at the checkpoint you want to load.
ckpt_path = "models/best_model_Adam.pth"

In [4]:
# 1) Build the CNN with 3 input channels and place it on the target device.
model = CNN_model(in_channel=3).to(device)

# 2) Load the checkpoint, remapping stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; consider passing
# weights_only=True if your PyTorch version supports it and the checkpoint
# holds nothing but tensors and primitive values.
ckpt = torch.load(ckpt_path, map_location=device)

# The file may be a dict holding the weights under "model_state_dict",
# or a bare state_dict; fall back to the object itself in the latter case.
state_dict = ckpt.get("model_state_dict", ckpt)

# 3) Copy the weights into the model. load_state_dict writes into the existing
#    parameters in place, so the model stays on `device` and no second
#    .to(device) call is needed.
model.load_state_dict(state_dict)

# 4) Switch BatchNorm and Dropout layers to inference behaviour.
#    NOTE(review): assumes this notebook evaluates the loaded model; call
#    model.train() again before resuming training.
model.eval()

In [None]:


# NOTE(review): this cell rebinds `model` (previously the CNN_model instance)
# and reads the same `ckpt_path` as the CNN cell. A checkpoint that loads into
# CNN_model cannot also load into ResNet-50 with strict key matching, so point
# `ckpt_path` at a ResNet-50 checkpoint before running this cell.

# 1) Recreate model architecture exactly as during training
num_classes = 9  # set to your number of classes
# Called without arguments so no pretrained weights are fetched; this avoids
# the deprecated `pretrained=` keyword and works on old and new torchvision.
model = models.resnet50()
# Replace the final classification layer with one sized for `num_classes`.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# 2) Load checkpoint (map_location remaps stored tensors onto `device`)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)

# ckpt may be either a dict with 'model_state_dict' or just a state_dict
state_dict = ckpt.get("model_state_dict", ckpt)

# If keys were saved under DataParallel they carry a leading 'module.' prefix.
# Strip only that leading prefix: str.replace would also delete any later
# "module." substring inside a key and silently corrupt the name.
dp_prefix = "module."
if any(key.startswith(dp_prefix) for key in state_dict):
    state_dict = {
        (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
        for key, value in state_dict.items()
    }

# 3) Load weights and move model to device
model.load_state_dict(state_dict)
model = model.to(device)

# 4) Switch BatchNorm layers to inference behaviour.
#    NOTE(review): assumes the loaded model is used for evaluation; call
#    model.train() again before resuming training.
model.eval()