## Imports

In [None]:
# Shell out to `hostnamectl` to print details of the host this kernel runs on
# (presumably to record which machine produced the results below).
!hostnamectl

In [None]:
# Switch the kernel's working directory to the project's src folder so that the
# relative paths used later ('../src', './braincnn_prototype.weights',
# 'training_metrics.pkl', ...) resolve from there.
# NOTE(review): this is an absolute, user-specific path and will fail on any
# other machine — consider making it configurable.
%cd /home/ir739wb/ilyarekun/bc_project/centralized-learning/src/

In [None]:
import sys
import os
sys.path.append('../src')
from data_preprocessing import data_preprocessing_tumor_stratified
from model import BrainCNN, EarlyStopping
from collections import defaultdict

import numpy as np
import random
import torch
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


In [None]:
# Seed every random number generator used in this notebook so that repeated
# runs start from the same state.
seed = 42

# Python's and NumPy's global generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch generators: default (CPU), current CUDA device, and all CUDA devices.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
# Build the train / validation / test loaders with the project's stratified
# preprocessing helper (imported from data_preprocessing).
(
    train_loader,
    valid_loader,
    test_loader,
) = data_preprocessing_tumor_stratified()
print("data was successfully loaded")

In [None]:
# Report how many samples the dataset behind each loader holds, one per line.
print(
    f"Train dataset size: {len(train_loader.dataset)}",
    f"Validation dataset size: {len(valid_loader.dataset)}",
    f"Test dataset size: {len(test_loader.dataset)}",
    sep="\n",
)


In [None]:

def count_images_per_class(loader):
    """Tally how many samples of each class label a loader yields.

    Iterates over every batch of ``loader``; each batch must unpack into two
    items, the second being an iterable of labels. Every label is converted
    to a plain Python value with ``.item()`` and its occurrences are counted.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        A ``defaultdict(int)`` mapping label value -> number of samples seen,
        with keys in order of first appearance.
    """
    tally = defaultdict(int)
    for _inputs, batch_labels in loader:
        for label_value in (entry.item() for entry in batch_labels):
            tally[label_value] += 1
    return tally

# Count the labels of every split up front; each call walks a whole loader.
train_class_counts = count_images_per_class(train_loader)
valid_class_counts = count_images_per_class(valid_loader)
test_class_counts = count_images_per_class(test_loader)

# Print the per-class totals split by split. The leading "\n" on the later
# headings puts a blank line between the sections of the output.
for heading, split_counts in (
    ("Train loader class counts:", train_class_counts),
    ("\nValidation loader class counts:", valid_class_counts),
    ("\nTest loader class counts:", test_class_counts),
):
    print(heading)
    for class_label, count in split_counts.items():
        print(f"Class {class_label}: {count} images")


In [None]:


# Create a fresh network and train it on the train/validation loaders.
# train_model is defined in model.py (outside this notebook); going by the
# names it returns per-epoch loss/accuracy histories plus an early-stopping
# object — NOTE(review): confirm against model.py.
model = BrainCNN()

(
    train_loss_metr,
    val_loss_metr,
    train_acc_metr,
    val_acc_metr,
    early_stopping,
) = model.train_model(
    train_loader,
    valid_loader,
    num_epochs=50,
    patience=6,
    delta=0.004,
    learning_rate=0.002,
    momentum=0.85,
    weight_decay=0.07,
    save_path="./braincnn_prototype.weights",
)



In [None]:
import matplotlib.pyplot as plt



# X-axis values: one entry per recorded epoch, numbered from 1.
epochs = range(1, len(train_loss_metr) + 1)

# Two side-by-side panels: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs, train_loss_metr, label="Train Loss", marker="o")
loss_ax.plot(epochs, val_loss_metr, label="Validation Loss", marker="o")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Training and Validation Loss")
loss_ax.legend()

acc_ax.plot(epochs, train_acc_metr, label="Train Accuracy", marker="o")
acc_ax.plot(epochs, val_acc_metr, label="Validation Accuracy", marker="o")
acc_ax.set_xlabel("Epochs")
acc_ax.set_ylabel("Accuracy")
acc_ax.set_title("Training and Validation Accuracy")
acc_ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Location of the trained weights file that the following cell loads.
save_path = "./braincnn_prototype.weights"

In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Build a fresh, untrained model instance and move it to the target device.
model = BrainCNN()
model.to(device)

# Read the saved weights, remapping their tensors onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(save_path, map_location=device)

# A checkpoint saved from an nn.DataParallel wrapper prefixes every key with
# "module.". Strip that prefix only where it leads the key, so parameter names
# that merely contain "module." (e.g. "submodule.weight") are left intact.
dp_prefix = "module."
new_state_dict = {
    (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(new_state_dict)

# With more than one GPU, wrap the loaded model in DataParallel.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)


In [None]:


# Final evaluation on the test set: run the model over every test batch with
# gradients disabled, collecting predictions and ground-truth labels.
model.eval()
correct = 0
total = 0
test_targets = []
test_preds = []

with torch.no_grad():
    for batch_inputs, batch_targets in test_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        outputs = model(batch_inputs)
        # Predicted class = index of the largest output along dim 1.
        predicted = torch.max(outputs, 1)[1]

        total += batch_targets.size(0)
        correct += (predicted == batch_targets).sum().item()

        # Keep labels and predictions on the CPU for the sklearn metrics.
        test_targets.extend(batch_targets.cpu().numpy())
        test_preds.extend(predicted.cpu().numpy())

test_accuracy = correct / total

# Class metrics averaged with sklearn's 'weighted' scheme.
precision = precision_score(test_targets, test_preds, average='weighted')
recall = recall_score(test_targets, test_preds, average='weighted')
f1 = f1_score(test_targets, test_preds, average='weighted')

# Print the test metrics, one per line, rounded to four decimals.
print(
    'Metrics of the model on the test images:',
    f'Accuracy: {test_accuracy:.4f}',
    f'Precision: {precision:.4f}',
    f'Recall: {recall:.4f}',
    f'F1 Score: {f1:.4f}',
    sep='\n',
)

# Persist the per-epoch training curves together with the final test metrics
# so they can be reloaded and plotted without retraining.
with open("training_metrics.pkl", "wb") as metrics_file:
    metrics_payload = dict(
        train_loss=train_loss_metr,
        val_loss=val_loss_metr,
        train_acc=train_acc_metr,
        val_acc=val_acc_metr,
        accuracy=test_accuracy,
        precision=precision,
        recall=recall,
        f1_score=f1,
    )
    pickle.dump(metrics_payload, metrics_file)

#torch.save(model.state_dict(), "./braincnn_prototype.weights")

In [None]:
# Reload the metrics pickled by the evaluation cell. The file name must match
# the one written there ("training_metrics.pkl"); reading a differently named
# file breaks a fresh Restart & Run All or silently plots another run's data.
# NOTE: pickle.load can execute arbitrary code — only read files this notebook
# produced itself.
with open("training_metrics.pkl", "rb") as f:
    metrics = pickle.load(f)

# Unpack the per-epoch histories used by the plots below.
train_loss_metr = metrics["train_loss"]
val_loss_metr = metrics["val_loss"]
train_acc_metr = metrics["train_acc"]
val_acc_metr = metrics["val_acc"]

# Two side-by-side panels drawn from the reloaded metric histories.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Loss curve
loss_ax.plot(train_loss_metr, label='Train Loss')
loss_ax.plot(val_loss_metr, label='Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Loss Curve')

# Accuracy curve
acc_ax.plot(train_acc_metr, label='Train Accuracy')
acc_ax.plot(val_acc_metr, label='Validation Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.set_title('Accuracy Curve')

# Save the figure to a file before displaying it.
fig.savefig("training_plots.png", dpi=300, bbox_inches="tight")
plt.show()


In [None]:


# Tick labels for the heatmap, shared by both axes.
# NOTE(review): assumes label indices 0-3 map to these classes in this order —
# confirm against the dataset's class-to-index mapping.
class_names = ['Glioma', 'Meningioma', 'notumor', 'Pituitary']

# Rows are true labels, columns are predicted labels.
cm = confusion_matrix(test_targets, test_preds)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names)

plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
# Save the figure to a file before displaying it.
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
plt.show()

