# In[ ]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Prerequisites (created in earlier cells, not here): a trained `model`, the
# `train_loader` / `val_loader` data loaders for training and validation, and
# the `device` (CPU/GPU) the tensors should be moved to.

# Switch the model to evaluation mode. NOTE: the training loop below calls
# `model.train()` at the start of every epoch, so this only has a lasting
# effect if the loop runs zero epochs.
model.eval()

# Per-epoch history for each split; the training loop appends one value per
# epoch to each list, and the plotting code reads them back.
train_losses, val_losses = [], []
train_precisions, val_precisions = [], []
train_recalls, val_recalls = [], []
train_f1_scores, val_f1_scores = [], []
train_accuracies, val_accuracies = [], []

# Function to calculate metrics
def calculate_metrics(y_true, y_pred, *, average='weighted', zero_division=0):
    """Return ``(precision, recall, f1, accuracy)`` for a set of class predictions.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels, aligned element-wise with ``y_true``.
        average: Averaging strategy forwarded to sklearn's precision/recall/F1
            scorers. The default ``'weighted'`` averages per-class scores
            weighted by each class's support (number of true samples).
        zero_division: Value reported for a class whose precision or recall is
            undefined (e.g. the class is never predicted). ``0`` yields the same
            number as sklearn's default ``"warn"`` but without emitting an
            ``UndefinedMetricWarning`` every epoch while the model still
            ignores some class.

    Returns:
        Tuple ``(precision, recall, f1, accuracy)``.
    """
    precision = precision_score(y_true, y_pred, average=average, zero_division=zero_division)
    recall = recall_score(y_true, y_pred, average=average, zero_division=zero_division)
    f1 = f1_score(y_true, y_pred, average=average, zero_division=zero_division)
    # Accuracy is a plain fraction of exact matches; no averaging/zero-division knobs.
    accuracy = accuracy_score(y_true, y_pred)
    return precision, recall, f1, accuracy

# Training loop (for plotting): train for `num_epochs` epochs, evaluate on
# `val_loader` after each one, and append loss + metrics to the history lists.
# NOTE(review): `num_epochs`, `criterion` and `optimizer` are expected to be
# defined in an earlier cell — they are not created here.
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    epoch_train_loss = 0.0
    epoch_train_preds = []
    epoch_train_labels = []

    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss; it is averaged over batches below.
        epoch_train_loss += loss.item()

        # Predicted class = index of the highest score along dim 1.
        preds = outputs.argmax(dim=1)
        epoch_train_preds.extend(preds.cpu().numpy())
        epoch_train_labels.extend(labels.cpu().numpy())

    # Mean over batches. NOTE(review): if `criterion` uses mean reduction this
    # is a mean of batch means, which slightly over-weights a short final batch.
    avg_train_loss = epoch_train_loss / len(train_loader)
    precision, recall, f1, accuracy = calculate_metrics(epoch_train_labels, epoch_train_preds)

    train_losses.append(avg_train_loss)
    train_precisions.append(precision)
    train_recalls.append(recall)
    train_f1_scores.append(f1)
    train_accuracies.append(accuracy)

    # ---- Validation pass (no gradient tracking, no parameter updates) ----
    model.eval()
    epoch_val_loss = 0.0
    epoch_val_preds = []
    epoch_val_labels = []

    with torch.no_grad():
        for data, labels in val_loader:
            data, labels = data.to(device), labels.to(device)

            outputs = model(data)
            loss = criterion(outputs, labels)

            epoch_val_loss += loss.item()

            preds = outputs.argmax(dim=1)
            epoch_val_preds.extend(preds.cpu().numpy())
            epoch_val_labels.extend(labels.cpu().numpy())

    avg_val_loss = epoch_val_loss / len(val_loader)
    # These names now hold the *validation* metrics (they overwrite the
    # training values computed above).
    precision, recall, f1, accuracy = calculate_metrics(epoch_val_labels, epoch_val_preds)

    val_losses.append(avg_val_loss)
    val_precisions.append(precision)
    val_recalls.append(recall)
    val_f1_scores.append(f1)
    val_accuracies.append(accuracy)

    # Read the accuracies back from the history lists: `accuracy` was
    # overwritten by the validation value, so using it for both columns would
    # report the validation accuracy twice.
    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracies[-1]:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracies[-1]:.4f}")

# Plot the per-epoch train/validation history recorded by the training loop.
def _plot_train_val(ax, epochs, train_values, val_values, metric_name):
    """Draw the train and validation curves for one metric on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        epochs: x-axis values (epoch numbers), same length as the value lists.
        train_values: Per-epoch training values of the metric.
        val_values: Per-epoch validation values of the metric.
        metric_name: Human-readable metric name used in the title, y-label and legend.
    """
    ax.plot(epochs, train_values, label=f'Train {metric_name}')
    ax.plot(epochs, val_values, label=f'Validation {metric_name}')
    ax.set_title(f'{metric_name} Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_name)
    ax.legend()


# 1-based x-axis so the curves line up with the "Epoch k/N" numbers printed
# during training (plotting a bare list would start the axis at 0).
epoch_axis = range(1, len(train_losses) + 1)

# Figure 1: loss and accuracy side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
_plot_train_val(axes[0], epoch_axis, train_losses, val_losses, 'Loss')
_plot_train_val(axes[1], epoch_axis, train_accuracies, val_accuracies, 'Accuracy')
fig.tight_layout()
plt.show()

# Figure 2: precision, recall and F1 score side by side.
fig, axes = plt.subplots(1, 3, figsize=(12, 6))
for ax, train_values, val_values, metric_name in zip(
    axes,
    (train_precisions, train_recalls, train_f1_scores),
    (val_precisions, val_recalls, val_f1_scores),
    ('Precision', 'Recall', 'F1 Score'),
):
    _plot_train_val(ax, epoch_axis, train_values, val_values, metric_name)
fig.tight_layout()
plt.show()
