In [None]:
def get_model(lr, num_classes, load_path=None, freeze_features=True):
    """Build an EfficientNet-B0 classifier with its optimizer and loss.

    The last classifier layer is replaced by a new ``nn.Linear`` that maps to
    ``num_classes`` outputs.

    Args:
        lr: Learning rate passed to the Adam optimizer.
        num_classes: Number of output classes for the new classifier layer.
        load_path: Optional path to a saved ``state_dict``. When given, the
            checkpoint is read with ``map_location=device``, loaded into the
            model, and the model is switched to eval mode.
        freeze_features: When true, every parameter in ``model.features`` has
            ``requires_grad`` set to ``False``.

    Returns:
        A ``(model, optimizer, loss_func)`` tuple. The optimizer is Adam over
        the parameters that still require gradients; the loss is
        ``nn.CrossEntropyLoss()``.
    """
    # A checkpoint overwrites every weight (load_state_dict is strict by
    # default), so fetching the ImageNet weights first would be wasted work
    # and would make loading a checkpoint depend on network access.
    model = models.efficientnet_b0(pretrained=not load_path)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    if freeze_features:
        # Freeze the feature extractor; the classifier stays trainable.
        for param in model.features.parameters():
            param.requires_grad = False

    if load_path:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version -- only load checkpoints you trust.
        model.load_state_dict(torch.load(load_path, map_location=device))
        model.eval()

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    return model, optimizer, nn.CrossEntropyLoss()



def loss_batch(model, loss_func, xb, yb, opt=None):
    """Compute the loss for one batch and optionally take an optimizer step.

    Args:
        model: Callable applied to the input batch to produce predictions.
        loss_func: Callable ``loss_func(preds, targets)`` returning a loss
            that supports ``.backward()`` and ``.item()``.
        xb: Input batch; moved to ``device`` before the forward pass.
        yb: Target batch; moved to ``device`` before the loss is computed.
        opt: Optional optimizer. When provided, the loss is back-propagated,
            the optimizer steps, and its gradients are then cleared.

    Returns:
        ``(loss_value, batch_size)`` -- the loss as a Python number and
        ``len`` of the input batch.
    """
    inputs = xb.to(device)
    targets = yb.to(device)
    batch_loss = loss_func(model(inputs), targets)

    # Evaluation path: no optimizer, so nothing to update.
    if opt is None:
        return batch_loss.item(), len(inputs)

    # Training path: gradients are cleared *after* the step, leaving the
    # optimizer clean for the next batch.
    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item(), len(inputs)


def fit(epochs, model, loss_func, opt, train_dl, valid_dl, which_epoch_unfreeze, writer):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    The learning rate follows a ``StepLR`` schedule that multiplies it by 0.2
    after every epoch. At epoch index ``which_epoch_unfreeze`` all parameters
    in ``model.features`` are made trainable and a fresh Adam optimizer
    (lr=1e-3) over every trainable parameter replaces ``opt`` inside this
    function; the schedule is restarted on that new optimizer.

    Args:
        epochs: Number of passes over ``train_dl``.
        model: Model to train; must expose ``model.features``.
        loss_func: Loss callable forwarded to ``loss_batch``.
        opt: Optimizer used until the unfreeze epoch.
        train_dl: Iterable of ``(xb, yb)`` training batches.
        valid_dl: Iterable of ``(xb, yb)`` validation batches.
        which_epoch_unfreeze: Zero-based epoch index at which the feature
            extractor is unfrozen.
        writer: Object with ``add_scalar(tag, value, step)`` used for logging;
            ``None`` disables logging.

    Returns:
        None. Per-epoch and final losses are printed.
    """
    model.to(device)

    def make_scheduler(optimizer):
        # Decay the learning rate by a factor of 0.2 after every epoch.
        return lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.2)

    scheduler = make_scheduler(opt)

    # Running sample counter across ALL epochs. Used as the TensorBoard step
    # so points from different epochs do not land on the same x positions.
    samples_seen = 0
    # Defined up front so the final print works even when epochs == 0.
    val_loss = float("nan")

    for epoch in range(epochs):

        if epoch == which_epoch_unfreeze:
            for param in model.features.parameters():
                param.requires_grad = True
            opt = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
            # The old scheduler is bound to the optimizer that was just
            # replaced; rebuild it so the schedule keeps affecting training.
            scheduler = make_scheduler(opt)

        total_loss = 0
        count = 0
        # Stays NaN if train_dl yields no batches this epoch.
        train_loss = float("nan")
        model.train()

        progress_bar = tqdm(train_dl, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        # loss_batch moves each batch to `device`, so no transfer is needed here.
        for xb, yb in progress_bar:
            loss, batch_size = loss_batch(model, loss_func, xb, yb, opt)
            total_loss += loss * batch_size
            count += batch_size
            samples_seen += batch_size
            train_loss = total_loss / count
            progress_bar.set_postfix(train_loss=train_loss)
            if writer is not None:
                writer.add_scalar('Loss/train', train_loss, samples_seen)

        scheduler.step()

        model.eval()
        val_total = 0.0
        val_count = 0

        with torch.no_grad():
            for xb, yb in valid_dl:
                loss, batch_size = loss_batch(model, loss_func, xb, yb)
                val_total += loss * batch_size
                val_count += batch_size

        # Sample-weighted mean; NaN (rather than a 0/0 error) when valid_dl is empty.
        val_loss = val_total / val_count if val_count else float("nan")
        if writer is not None:
            writer.add_scalar('Loss/val', val_loss, epoch)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    print(f"Final Validation Loss = {val_loss:.4f}")

def test_model(model, test_loader, loss_func, class_names=None, writer=None):
    """Evaluate ``model`` on ``test_loader`` and plot a confusion matrix.

    Args:
        model: Model to evaluate; put in eval mode and moved to ``device``.
        test_loader: Iterable of ``(xb, yb)`` batches.
        loss_func: Callable ``loss_func(outputs, targets)`` returning a scalar loss.
        class_names: Optional sequence of display names used as tick labels.
            Assumes ``class_names[i]`` names integer label ``i`` -- confirm
            against the dataset's label encoding.
        writer: Optional object with ``add_image(tag, tensor, step)``; when
            given, the rendered confusion matrix is logged under
            ``"Confusion_Matrix"`` at step 1.

    Returns:
        ``(avg_loss, accuracy, cm)``: the sample-weighted mean loss, the
        fraction of correct predictions, and the confusion matrix.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    model.to(device)

    total_correct = 0
    total_samples = 0
    total_loss = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for xb, yb in tqdm(test_loader, desc="Testing", leave=False):
            xb, yb = xb.to(device), yb.to(device)
            outputs = model(xb)
            loss = loss_func(outputs, yb)

            batch_size = xb.size(0)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * batch_size

            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == yb).sum().item()
            total_samples += batch_size

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(yb.cpu().numpy())

    if total_samples == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    print(f"\nTest Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

    # `is not None` + len() instead of truthiness so that a NumPy array of
    # names does not raise "truth value of an array is ambiguous".
    has_names = class_names is not None and len(class_names) > 0

    # Confusion Matrix
    # With names supplied, pin the label set so that a class missing from the
    # test data still gets its own row/column and the ticks stay aligned.
    cm_labels = list(range(len(class_names))) if has_names else None
    cm = confusion_matrix(all_labels, all_preds, labels=cm_labels)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=list(class_names) if has_names else range(cm.shape[1]),
                yticklabels=list(class_names) if has_names else range(cm.shape[0]))
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(f"Confusion Matrix\nLoss: {avg_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")
    plt.tight_layout()

    if writer is not None:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # The saved PNG carries an alpha channel; convert to 3-channel RGB,
            # which is the default layout add_image documents.
            image = transforms.ToTensor()(Image.open(buf).convert("RGB"))
        writer.add_image("Confusion_Matrix", image, 1)
    plt.show()
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close(fig)

    return avg_loss, accuracy, cm
