In [None]:
import pandas as pd
import numpy as np
from prettytable import PrettyTable
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import fbeta_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ----------------------------
# Submission Generation Utils
# ----------------------------
def generate_submission(predictions, experiment_ids, particle_types, output_file="submission.csv"):
    """Write a submission CSV with one row per predicted particle.

    Args:
        predictions: Sequence of ``(x, y, z)`` coordinates. Each entry may be a
            tuple or list of length 3, or a 1-D NumPy array of shape ``(3,)``.
        experiment_ids: Experiment id for each prediction (same length).
        particle_types: Particle type name for each prediction (same length).
        output_file: Destination CSV path.

    Returns:
        The written ``pandas.DataFrame`` with columns
        ``id, experiment, particle_type, x, y, z``; ``id`` is the 0-based row index.

    Raises:
        ValueError: If the three input sequences differ in length, or a
            prediction is not a 3-element coordinate.
    """
    if not (len(predictions) == len(experiment_ids) == len(particle_types)):
        raise ValueError("Input lists must have the same length.")

    rows = []
    for idx, (pred, experiment, particle_type) in enumerate(zip(predictions, experiment_ids, particle_types)):
        # Accept the common 3-element containers model code produces (tuples,
        # lists, 1-D arrays), not only tuples.
        is_sequence_coord = isinstance(pred, (tuple, list)) and len(pred) == 3
        is_array_coord = isinstance(pred, np.ndarray) and pred.shape == (3,)
        if not (is_sequence_coord or is_array_coord):
            raise ValueError("Each prediction must be a tuple of three floats (x, y, z).")
        x, y, z = pred

        rows.append({
            "id": idx,
            "experiment": experiment,
            "particle_type": particle_type,
            "x": x,
            "y": y,
            "z": z,
        })

    # Explicit columns keep the header and column order even when there are no rows.
    submission_df = pd.DataFrame(rows, columns=["id", "experiment", "particle_type", "x", "y", "z"])
    submission_df.to_csv(output_file, index=False)
    return submission_df

# ----------------------------
# Dataset Class
# ----------------------------
class CryoETDataset(Dataset):
    """Map-style dataset pairing sample volumes with optional integer class labels.

    Args:
        data: Indexable collection of samples (e.g. a NumPy array, list, or
            tensor); each sample must be convertible with ``torch.as_tensor``.
        labels: Optional indexable collection of integer class labels aligned
            with ``data``. When ``None``, every item gets the ``UNLABELED`` label.
        augment: If True, each sample is randomly flipped (p=0.5) along its
            last axis.
    """

    # Label returned for unlabeled samples. The default DataLoader collate
    # function cannot batch ``None``, so a tensor placeholder keeps ``(x, y)``
    # batches working for test-time loaders. It is negative, so it can never be
    # mistaken for a valid class index.
    UNLABELED = -1

    def __init__(self, data, labels=None, augment=False):
        self.data = data
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` as ``(float32 tensor, int64 scalar tensor)``."""
        # as_tensor accepts arrays, lists and tensors without a copy warning;
        # clone() makes sure the result never aliases the caller's storage.
        x = torch.as_tensor(self.data[idx], dtype=torch.float32).clone()
        if self.augment:
            x = self.augment_data(x)
        label = self.UNLABELED if self.labels is None else self.labels[idx]
        return x, torch.as_tensor(label, dtype=torch.long)

    def augment_data(self, x):
        """Randomly flip ``x`` along its last axis with probability 0.5."""
        # Convert first: NumPy arrays have no ``.flip`` method.
        x = torch.as_tensor(x)
        if torch.rand(1).item() > 0.5:
            # NOTE(review): CryoETModel's first conv takes 1 input channel, so
            # samples are presumably (1, D, H, W). Flipping axis 0 would only
            # reverse the size-1 channel axis (a no-op), so flip a spatial axis.
            x = x.flip(dims=[-1])
        return x

# ----------------------------
# Model Definition
# ----------------------------
class CryoETModel(nn.Module):
    """Small 3-D CNN classifier: two conv/ReLU/max-pool stages, then two linear layers.

    Input is expected as ``(batch, 1, depth, height, width)``; the output is
    ``(batch, num_classes)`` logits.

    Args:
        num_classes: Number of output logits.
        input_size: Spatial size of the input volume, either an int (cubic
            volume) or a ``(depth, height, width)`` tuple. The default of 64
            reproduces the previous hard-coded ``64 * 16 * 16 * 16`` flatten size.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two 2x poolings.
    """

    def __init__(self, num_classes=5, input_size=64):
        super().__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size, input_size)
        depth, height, width = input_size
        # Each max_pool3d(kernel=2) floor-halves every spatial axis; there are two.
        pooled_voxels = (depth // 4) * (height // 4) * (width // 4)
        if pooled_voxels == 0:
            raise ValueError(f"input_size {tuple(input_size)} is too small; every axis must be >= 4.")

        self.conv1 = nn.Conv3d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * pooled_voxels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.max_pool3d(torch.relu(self.conv1(x)), 2)
        x = torch.max_pool3d(torch.relu(self.conv2(x)), 2)
        # flatten works on non-contiguous tensors too, unlike view.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# ----------------------------
# Training & Evaluation
# ----------------------------
def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, checkpoint_path="best_model.pth"):
    """Train ``model`` with Adam and cross-entropy, checkpointing the best validation epoch.

    After every epoch the model is scored with ``evaluate_model`` on
    ``val_loader``. Whenever validation accuracy improves, ``model.state_dict()``
    is saved to ``checkpoint_path``.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        train_loader: Iterable of ``(inputs, integer labels)`` batches.
        val_loader: Iterable of ``(inputs, integer labels)`` batches for validation.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        checkpoint_path: Where the best state dict is written.

    Returns:
        The best validation accuracy observed, or ``None`` if ``num_epochs`` is 0.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # None rather than 0.0, so the first epoch is always checkpointed, even at 0% accuracy.
    best_accuracy = None

    for epoch in range(num_epochs):
        model.train()
        train_loss, num_batches = 0.0, 0
        for x, y in tqdm(train_loader):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_batches += 1

        val_accuracy = evaluate_model(model, val_loader)
        if best_accuracy is None or val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)

        # Count batches instead of using len(train_loader). This avoids
        # ZeroDivisionError on an empty loader and supports loaders without __len__.
        mean_loss = train_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Train Loss: {mean_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    return best_accuracy

def evaluate_model(model, data_loader):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    Puts ``model`` in eval mode (it is left there) and runs without gradients.
    The predicted class is the argmax of the logits along dim 1.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        data_loader: Iterable of ``(inputs, integer labels)`` batches.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: If ``data_loader`` yields no samples, so accuracy is undefined.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            predicted = model(x).argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined.")
    return correct / total

def calculate_fbeta(y_true, y_pred, beta=4, average='micro'):
    """Return the F-beta score of ``y_pred`` against ``y_true`` via scikit-learn.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        beta: Weight of recall relative to precision (beta > 1 favours recall).
        average: Averaging mode forwarded to ``sklearn.metrics.fbeta_score``.
            Note: for single-label multiclass input with all labels included,
            micro-averaged precision and recall both equal accuracy. The
            default ``'micro'`` therefore makes ``beta`` irrelevant; pass
            ``'macro'`` or ``'weighted'`` for a beta-sensitive score.

    Returns:
        The F-beta score (a float for averaged modes).
    """
    return fbeta_score(y_true, y_pred, beta=beta, average=average)

# ----------------------------
# Submission Preparation
# ----------------------------
def prepare_submission(model, test_loader, output_file="/content/submission.csv",
                       experiment="TS_5_4", class_names=None):
    """Classify every test sample and write a placeholder submission CSV.

    NOTE: the model only predicts a particle class. ``x``/``y``/``z`` are
    uniform random placeholders in [0, 1) drawn from NumPy's global RNG, and
    every row is attributed to the single ``experiment`` id. The CSV is
    therefore a format check, not a meaningful localization submission.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        test_loader: Iterable yielding ``(inputs, ignored)`` pairs.
        output_file: Destination CSV path.
        experiment: Experiment id written on every row.
        class_names: Particle type name for each class index. Defaults to
            ribosome, virus-like-particle, apo-ferritin, thyroglobulin,
            beta-galactosidase (ASCII spellings, as shown in the notebook's output).

    Returns:
        The written ``pandas.DataFrame``.
    """
    if class_names is None:
        class_names = ("ribosome", "virus-like-particle", "apo-ferritin",
                       "thyroglobulin", "beta-galactosidase")

    model.eval()
    rows = []
    with torch.no_grad():
        for x, _ in test_loader:
            predicted = model(x).argmax(dim=1)
            for class_idx in predicted.tolist():
                rows.append({
                    # A running count gives unique, contiguous ids even when the
                    # last batch is smaller than the others.
                    "id": len(rows),
                    "experiment": experiment,
                    "particle_type": class_names[class_idx],
                    "x": np.random.uniform(),
                    "y": np.random.uniform(),
                    "z": np.random.uniform(),
                })

    # Explicit columns keep the header even if the loader was empty.
    submission_df = pd.DataFrame(rows, columns=["id", "experiment", "particle_type", "x", "y", "z"])
    submission_df.to_csv(output_file, index=False)
    return submission_df

# Preview the saved submission: load the CSV and render its first 10 rows
# as an ASCII table.
submission_df = pd.read_csv("/content/submission.csv")

table = PrettyTable()
table.field_names = list(submission_df.columns)
# to_numpy() rows carry the same per-cell values that iterrows() would yield.
for values in submission_df.head(10).to_numpy().tolist():
    table.add_row(values)
print(table)

+----+------------+---------------------+----------+---------+---------+
| id | experiment |    particle_type    |    x     |    y    |    z    |
+----+------------+---------------------+----------+---------+---------+
| 0  |   TS_5_4   |     beta-amylase    | 2983.596 | 3154.13 | 764.124 |
| 1  |   TS_5_4   |  beta-galactosidase | 2983.596 | 3154.13 | 764.124 |
| 2  |   TS_6_4   |       ribosome      | 2983.596 | 3154.13 | 764.124 |
| 3  |   TS_6_4   |     apo-ferritin    | 2983.596 | 3154.13 | 764.124 |
| 4  |  TS_69_2   | virus-like-particle | 2983.596 | 3154.13 | 764.124 |
| 5  |   TS_5_4   |     beta-amylase    | 2983.596 | 3154.13 | 764.124 |
| 6  |   TS_5_4   |  beta-galactosidase | 2983.596 | 3154.13 | 764.124 |
| 7  |   TS_6_4   |       ribosome      | 2983.596 | 3154.13 | 764.124 |
| 8  |   TS_6_4   |     apo-ferritin    | 2983.596 | 3154.13 | 764.124 |
| 9  |  TS_69_2   | virus-like-particle | 2983.596 | 3154.13 | 764.124 |
+----+------------+---------------------+----------+---------+---------+