In [1]:

!pip install torch-geometric torch-scatter


Collecting torch-geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l- \ | / done
[?25h  Getting requirements to build wheel ... [?25l- done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- done
[?25hCollecting torch-scatter
  Downloading torch_scatter-2.1.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
Building wheels for collected packages: torch-geometric, torch-scatter
  Building wheel for torch-geometric (pyproject.toml) ... [?25l- \ | / - \ done
[?25h  Created wheel for torch-geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910454 sha256=7f07beece81c2093b7aa3c71d13bc6f2558e3946c9fcbca370b7c269

In [2]:

import numpy as np  # 🧮 NumPy for numerical computations
import pandas as pd  # 🐼 Pandas for data manipulation
import os  # 📂 Operating system-related functions
from tqdm import tqdm  # 🔄 tqdm for progress bar visualization

import sklearn  # 🧬 scikit-learn for machine learning utilities
import sklearn.model_selection  # 📊 scikit-learn's model selection module
import torch  # 🔥 PyTorch for deep learning
from torch import nn  # 🧠 PyTorch's neural network module
from torch import Tensor  # 🚀 PyTorch's Tensor data type
from torch_geometric.nn import GCNConv  # 📊 Graph Convolutional Network layer
from torch_geometric.datasets import Planetoid  # 🌍 PyTorch Geometric dataset for graph data
from torch.utils.data import DataLoader, Dataset  # 📦 PyTorch data loading utilities
from timm.scheduler import CosineLRScheduler  # 📈 Learning rate scheduler
import matplotlib.pyplot as plt  # 📊 Matplotlib for plotting

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # ⚙️ Determine if CUDA (GPU) is available




In [3]:
# 📁 Define a Function to Load DataFrames
# This function loads data stored in different splits (train, valid, test) from a specified directory.
# It reads files in the directory, extracts data using NumPy, and organizes it into DataFrames.

def load_df(directory):
    splits = ["train", "valid", "test"]  # 🔄 List of data splits
    dfs = dict()  # 📊 Dictionary to store DataFrames for each split
    
    for split in splits:
        path = os.path.join(directory, split)  # 📂 Define the path to the split's directory
        files = os.listdir(path)  # 🗂 Get a list of files in the split's directory
        list_df = []  # 📄 List to store data dictionaries
        
        for file in files:
            d = dict(np.load(os.path.join(path, file)))  # 📦 Load data using NumPy
            d['file'] = file  # 📄 Include the file name in the data dictionary
            list_df.append(d)  # 🧾 Append the data dictionary to the list
        dfs[split] = pd.DataFrame.from_dict(list_df)  # 🐼 Create a DataFrame from the list of data dictionaries and store it in the dictionary
    return dfs

# 📄 Load data using the defined function and store it in the 'tile_xla' variable
tile_xla = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/tile/xla/")


# 📦 Define Dataset and Model

In [4]:
# 📦 Define Custom Dataset Class
# This class, 'TileDataset', is a custom dataset class for our machine learning task.
# It inherits from the PyTorch 'Dataset' class and implements the necessary methods (__init__, __len__, and __getitem__).

class TileDataset(Dataset):
    def __init__(self, df):
        self.df = df  # 💼 Initialize the dataset with a DataFrame containing the data

    def __len__(self):
        return len(self.df)  # 🔢 Define the length of the dataset, which is the number of rows in the DataFrame

    def __getitem__(self, idx):
        row = self.df.iloc[idx]  # 📄 Get a specific row from the DataFrame based on the provided index
        config_feat = torch.tensor(row['config_feat'].astype(np.float32))  # 🧮 Convert and store 'config_feat' as a PyTorch tensor
        node_feat = torch.tensor(row['node_feat'].astype(np.float32))  # 🧮 Convert and store 'node_feat' as a PyTorch tensor
        node_opcode = torch.tensor(row['node_opcode'].astype(np.int32))  # 🧮 Convert and store 'node_opcode' as a PyTorch tensor
        edge_index = torch.tensor(np.swapaxes(row['edge_index'],0,1).astype(np.int32))  # 🧮 Convert and store 'edge_index' as a PyTorch tensor with axis swapping
        target = (row['config_runtime'] / (row['config_runtime_normalizers'] + 1e-5)).astype(np.float32)  # 📈 Calculate and store the target value with preprocessing
        # 📊 Min-max scale the target value to ensure it's within a specific range (standardization)
        target = (target - np.mean(target)) / (np.std(target) + 1e-5)
        target = torch.tensor(target)  # 🧮 Convert and store the target as a PyTorch tensor
        return config_feat, node_feat, node_opcode, edge_index, target  # 🔁 Return the data and target for a specific sample

# This class defines the structure of our custom dataset, converting and preprocessing data as necessary for training and evaluation.
# The relevant emojis provide a visual context for each part of the code.


In [5]:
# 🧠 Define Simple Neural Network Model
# In this cell, we define a simple neural network model named 'SimpleModel'.
# This model takes input data with specified dimensions and passes it through convolutional and dense layers.

class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_channels, graph_feats, hidden_dim):
        super().__init__()  # 🧬 Initialize the parent class 'torch.nn.Module'
        
        op_embedding_dim = 4  # I choose 4-dimensional embedding
        self.embedding = torch.nn.Embedding(120,  # 120 different op-codes
                                            op_embedding_dim,
                                           )
        assert len(hidden_channels) > 0
        in_channels = op_embedding_dim + 140
        self.convs = torch.nn.ModuleList()
        last_dim = hidden_channels[0]
        
        # Create a sequence of Graph Convolutional Network (GCN) layers
        self.convs.append(GCNConv(in_channels, hidden_channels[0]))
        for i in range(len(hidden_channels) - 1):
            self.convs.append(GCNConv(hidden_channels[i], hidden_channels[i+1]))
            last_dim = hidden_channels[i+1]
        self.convs.append(GCNConv(last_dim, graph_feats))
        
        # Define a sequential dense neural network
        self.dense = torch.nn.Sequential(nn.Linear(graph_feats + 24, 64),
                                         nn.ReLU(),
                                         nn.Linear(64, 64),
                                         nn.ReLU(),
                                         nn.Linear(64, 1),
                                        )

    def forward(self, x_cfg: Tensor, x_feat: Tensor, x_op: Tensor, edge_index: Tensor) -> Tensor:
        
        # Get graph features
        x = torch.cat([x_feat, self.embedding(x_op)], dim=1)  # 📊 Concatenate input features with opcode embeddings
        
        # Pass data through convolutional layers
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        
        # Get 1D graph embedding using average pooling
        x_graph = torch.mean(x, 0)
        
        # Combine graph data with config data
        x = torch.cat([x_cfg, x_graph.repeat((len(x_cfg), 1))], axis=1)  # 🔄 Concatenate config data with repeated graph embeddings
        
        # Pass the combined data through the dense neural network
        x = torch.flatten(self.dense(x))
        
        # Standardize the output
        x = (x - torch.mean(x)) / (torch.std(x) + 1e-5)
        return x

# Create an instance of the 'SimpleModel' and move it to the specified device (CPU or GPU)
model = SimpleModel(hidden_channels=[16, 32, 16, 48], graph_feats=64, hidden_dim=64).to(device)


# 🚂 Train  Epoch

In [6]:
# 📊 Concatenate DataFrames
# In this cell, we concatenate DataFrames 'train' and 'valid' from the 'tile_xla' dictionary along the row axis.
# We then reset the index of the resulting DataFrame for consistent indexing.

# Concatenate 'train' and 'valid' DataFrames along the row axis and reset the index
df = pd.concat((tile_xla["train"], tile_xla["valid"]), axis=0).reset_index(drop=True)

# This operation combines the training and validation data for further processing, ensuring a unified DataFrame.


In [7]:
# 🔄 Cross-Validation Training Loop (Enhanced)

# Define the score_tile_mean function
def score_tile_mean(predictions, df):
    score = 0
    for i in range(len(df)):
        predbest = np.mean(df.iloc[i]['config_runtime'][predictions[i]])
        best = np.mean(np.sort(df.iloc[i]['config_runtime'])[:5])
        score += 2 - predbest / best
    score /= len(df)
    return score

# Define the score_tile_max function
def score_tile_max(predictions, df):
    score = 0
    for i in range(len(df)):
        predbest = np.min(df.iloc[i]['config_runtime'][predictions[i]])
        best = np.min(df.iloc[i]['config_runtime'])
        score += 2 - predbest / best
    score /= len(df)
    return score

# Create a K-Fold cross-validator with 5 splits
kfold = sklearn.model_selection.KFold(n_splits=5, shuffle=True, random_state=0)

# Lists to store mean and max scores for each fold
score_means = []
score_maxs = []

# Define hyperparameters
learning_rate = 5e-4  # Adjust the learning rate to a different value
weight_decay = 1e-6  # Adjust weight decay to a different value
num_epochs = 90  # You can keep the number of epochs as 90 or adjust as needed


# Iterate through each fold
for fold, (tr_idx, va_idx) in enumerate(kfold.split(df)):
    train_dataset = TileDataset(df.iloc[tr_idx])
    val_dataset = TileDataset(df.iloc[va_idx])
    criterion = torch.nn.MSELoss()
    steps = len(train_dataset) * num_epochs  # Update the number of training steps
    warmup_steps = int(steps * 0.1)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = CosineLRScheduler(optimizer, t_initial=steps, warmup_t=warmup_steps, warmup_lr_init=1e-6, lr_min=2e-8)

    best_score = 0
    best_score_max = 0

    # Training loop with increased epochs
    for epoch in range(num_epochs):
        model.train()
        pbar = tqdm(range(len(train_dataset)), leave=False)
        loss_sum = 0
        n = 0
        
        for i in pbar:
            cfg_ft, nd_ft, nd_op, ind, target = train_dataset[i]
            cfg_ft, nd_ft, nd_op, ind, target = cfg_ft.to(device), nd_ft.to(device), nd_op.to(device), ind.to(device), target.to(device)
            
            out = model(cfg_ft, nd_ft, nd_op, ind)
            loss = criterion(out, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
            scheduler.step(i + len(train_dataset) * epoch)
            optimizer.step()
            loss_sum += loss.item()
            n += 1
            pbar.set_description(f'running loss: {(loss_sum/n):.2f}, current loss: {(loss.item()):.2f}')
        pbar.close()
        model.eval()
        tile_xla_predictions = []
        pbar = tqdm(range(len(val_dataset)), leave=False)
        
        for i in pbar:
            cfg_ft, nd_ft, nd_op, ind, target = val_dataset[i]
            cfg_ft, nd_ft, nd_op, ind, target = cfg_ft.to(device), nd_ft.to(device), nd_op.to(device), ind.to(device), target.to(device)
            
            out = model(cfg_ft, nd_ft, nd_op, ind)
            tile_xla_predictions.append(np.argsort(out.cpu().detach().numpy())[:5])
        pbar.close()
        
        # Calculate and display scores for the current fold and epoch
        score_mean = score_tile_mean(tile_xla_predictions, val_dataset.df)
        score_max = score_tile_max(tile_xla_predictions, val_dataset.df)
        print(f'fold {fold} epoch {epoch}, comp_score = {score_max:.3f}, mean_score = {score_mean:.3f},')
        
        # Update best scores and save the model if the mean score improves
        if score_mean > best_score:
            best_score = score_mean
            best_score_max = score_max
            torch.save(model.state_dict(), f'best_model_{fold}.pth')
    
    # Append the best scores for this fold to the respective lists
    score_means.append(best_score)
    score_maxs.append(best_score_max)

# Calculate and display the mean scores across all folds
print(f'comp_score = {np.mean(score_maxs)}, mean_score = {np.mean(score_means)},')


                                                    

fold 0 epoch 0, comp_score = 0.611, mean_score = 0.020,


                                                    

fold 0 epoch 1, comp_score = 0.602, mean_score = 0.028,


                                                    

fold 0 epoch 2, comp_score = 0.757, mean_score = 0.205,


                                                    

fold 0 epoch 3, comp_score = 0.649, mean_score = 0.113,


                                                    

fold 0 epoch 4, comp_score = 0.837, mean_score = 0.648,


                                                    

fold 0 epoch 5, comp_score = 0.654, mean_score = 0.365,


                                                    

fold 0 epoch 6, comp_score = 0.824, mean_score = 0.664,


                                                    

fold 0 epoch 7, comp_score = 0.889, mean_score = 0.769,


                                                    

fold 0 epoch 8, comp_score = 0.919, mean_score = 0.826,


                                                    

fold 0 epoch 9, comp_score = 0.903, mean_score = 0.788,


                                                    

fold 0 epoch 10, comp_score = 0.933, mean_score = 0.849,


                                                    

fold 0 epoch 11, comp_score = 0.929, mean_score = 0.838,


                                                    

fold 0 epoch 12, comp_score = 0.929, mean_score = 0.850,


                                                    

fold 0 epoch 13, comp_score = 0.928, mean_score = 0.841,


                                                    

fold 0 epoch 14, comp_score = 0.931, mean_score = 0.839,


                                                    

fold 0 epoch 15, comp_score = 0.927, mean_score = 0.835,


                                                    

fold 0 epoch 16, comp_score = 0.923, mean_score = 0.836,


                                                    

fold 0 epoch 17, comp_score = 0.927, mean_score = 0.837,


                                                    

fold 0 epoch 18, comp_score = 0.925, mean_score = 0.834,


                                                    

fold 0 epoch 19, comp_score = 0.926, mean_score = 0.831,


                                                    

fold 0 epoch 20, comp_score = 0.919, mean_score = 0.813,


                                                    

fold 0 epoch 21, comp_score = 0.921, mean_score = 0.813,


                                                    

fold 0 epoch 22, comp_score = 0.929, mean_score = 0.826,


                                                    

fold 0 epoch 23, comp_score = 0.929, mean_score = 0.838,


                                                    

fold 0 epoch 24, comp_score = 0.930, mean_score = 0.838,


                                                    

fold 0 epoch 25, comp_score = 0.917, mean_score = 0.820,


                                                    

fold 0 epoch 26, comp_score = 0.930, mean_score = 0.841,


                                                    

fold 0 epoch 27, comp_score = 0.927, mean_score = 0.836,


                                                    

fold 0 epoch 28, comp_score = 0.912, mean_score = 0.812,


                                                    

fold 0 epoch 29, comp_score = 0.924, mean_score = 0.825,


                                                    

fold 0 epoch 30, comp_score = 0.925, mean_score = 0.836,


                                                    

fold 0 epoch 31, comp_score = 0.927, mean_score = 0.841,


                                                    

fold 0 epoch 32, comp_score = 0.930, mean_score = 0.839,


                                                    

fold 0 epoch 33, comp_score = 0.923, mean_score = 0.825,


                                                    

fold 0 epoch 34, comp_score = 0.916, mean_score = 0.818,


                                                    

fold 0 epoch 35, comp_score = 0.925, mean_score = 0.831,


                                                    

fold 0 epoch 36, comp_score = 0.924, mean_score = 0.829,


                                                    

fold 0 epoch 37, comp_score = 0.926, mean_score = 0.832,


                                                    

fold 0 epoch 38, comp_score = 0.926, mean_score = 0.827,


                                                    

fold 0 epoch 39, comp_score = 0.923, mean_score = 0.828,


                                                    

fold 0 epoch 40, comp_score = 0.926, mean_score = 0.838,


                                                    

fold 0 epoch 41, comp_score = 0.922, mean_score = 0.828,


                                                    

fold 0 epoch 42, comp_score = 0.924, mean_score = 0.826,


                                                    

fold 0 epoch 43, comp_score = 0.917, mean_score = 0.821,


                                                    

fold 0 epoch 44, comp_score = 0.916, mean_score = 0.819,


                                                    

fold 0 epoch 45, comp_score = 0.915, mean_score = 0.805,


                                                    

fold 0 epoch 46, comp_score = 0.914, mean_score = 0.812,


                                                    

fold 0 epoch 47, comp_score = 0.914, mean_score = 0.809,


                                                    

fold 0 epoch 48, comp_score = 0.914, mean_score = 0.807,


                                                    

fold 0 epoch 49, comp_score = 0.914, mean_score = 0.806,


                                                    

fold 0 epoch 50, comp_score = 0.922, mean_score = 0.810,


                                                    

fold 0 epoch 51, comp_score = 0.919, mean_score = 0.811,


                                                    

fold 0 epoch 52, comp_score = 0.912, mean_score = 0.804,


                                                    

fold 0 epoch 53, comp_score = 0.916, mean_score = 0.809,


                                                    

fold 0 epoch 54, comp_score = 0.913, mean_score = 0.803,


                                                    

fold 0 epoch 55, comp_score = 0.916, mean_score = 0.809,


                                                    

fold 0 epoch 56, comp_score = 0.922, mean_score = 0.818,


                                                    

fold 0 epoch 57, comp_score = 0.917, mean_score = 0.813,


                                                    

fold 0 epoch 58, comp_score = 0.915, mean_score = 0.816,


                                                    

fold 0 epoch 59, comp_score = 0.917, mean_score = 0.815,


                                                    

fold 0 epoch 60, comp_score = 0.914, mean_score = 0.809,


                                                    

fold 0 epoch 61, comp_score = 0.917, mean_score = 0.813,


                                                    

fold 0 epoch 62, comp_score = 0.915, mean_score = 0.810,


                                                    

fold 0 epoch 63, comp_score = 0.918, mean_score = 0.808,


                                                    

fold 0 epoch 64, comp_score = 0.919, mean_score = 0.809,


                                                    

fold 0 epoch 65, comp_score = 0.915, mean_score = 0.805,


                                                    

fold 0 epoch 66, comp_score = 0.916, mean_score = 0.806,


                                                    

fold 0 epoch 67, comp_score = 0.921, mean_score = 0.811,


                                                    

fold 0 epoch 68, comp_score = 0.923, mean_score = 0.814,


                                                    

fold 0 epoch 69, comp_score = 0.914, mean_score = 0.802,


                                                    

fold 0 epoch 70, comp_score = 0.917, mean_score = 0.809,


                                                    

fold 0 epoch 71, comp_score = 0.920, mean_score = 0.814,


                                                    

fold 0 epoch 72, comp_score = 0.918, mean_score = 0.813,


                                                    

fold 0 epoch 73, comp_score = 0.919, mean_score = 0.813,


                                                    

fold 0 epoch 74, comp_score = 0.921, mean_score = 0.818,


                                                    

fold 0 epoch 75, comp_score = 0.919, mean_score = 0.814,


                                                    

fold 0 epoch 76, comp_score = 0.920, mean_score = 0.820,


                                                    

fold 0 epoch 77, comp_score = 0.920, mean_score = 0.821,


                                                    

fold 0 epoch 78, comp_score = 0.921, mean_score = 0.822,


                                                    

fold 0 epoch 79, comp_score = 0.921, mean_score = 0.822,


                                                    

fold 0 epoch 80, comp_score = 0.922, mean_score = 0.824,


                                                    

fold 0 epoch 81, comp_score = 0.921, mean_score = 0.823,


                                                    

fold 0 epoch 82, comp_score = 0.922, mean_score = 0.825,


                                                    

fold 0 epoch 83, comp_score = 0.923, mean_score = 0.824,


                                                    

fold 0 epoch 84, comp_score = 0.923, mean_score = 0.823,


                                                    

fold 0 epoch 85, comp_score = 0.922, mean_score = 0.822,


                                                    

fold 0 epoch 86, comp_score = 0.922, mean_score = 0.818,


                                                    

fold 0 epoch 87, comp_score = 0.921, mean_score = 0.811,


                                                    

fold 0 epoch 88, comp_score = 0.921, mean_score = 0.810,


                                                    

fold 0 epoch 89, comp_score = 0.922, mean_score = 0.810,


                                                    

fold 1 epoch 0, comp_score = 0.907, mean_score = 0.787,


                                                    

fold 1 epoch 1, comp_score = 0.910, mean_score = 0.793,


                                                    

fold 1 epoch 2, comp_score = 0.909, mean_score = 0.794,


                                                    

fold 1 epoch 3, comp_score = 0.911, mean_score = 0.796,


                                                    

fold 1 epoch 4, comp_score = 0.906, mean_score = 0.789,


                                                    

fold 1 epoch 5, comp_score = 0.889, mean_score = 0.773,


                                                    

fold 1 epoch 6, comp_score = 0.892, mean_score = 0.785,


                                                    

fold 1 epoch 7, comp_score = 0.880, mean_score = 0.773,


                                                    

fold 1 epoch 8, comp_score = 0.895, mean_score = 0.785,


                                                    

fold 1 epoch 9, comp_score = 0.888, mean_score = 0.790,


                                                    

fold 1 epoch 10, comp_score = 0.893, mean_score = 0.794,


                                                    

fold 1 epoch 11, comp_score = 0.895, mean_score = 0.789,


                                                    

fold 1 epoch 12, comp_score = 0.889, mean_score = 0.781,


                                                    

fold 1 epoch 13, comp_score = 0.908, mean_score = 0.795,


                                                    

fold 1 epoch 14, comp_score = 0.891, mean_score = 0.786,


                                                    

fold 1 epoch 15, comp_score = 0.882, mean_score = 0.769,


                                                    

fold 1 epoch 16, comp_score = 0.895, mean_score = 0.799,


                                                    

fold 1 epoch 17, comp_score = 0.896, mean_score = 0.784,


                                                    

fold 1 epoch 18, comp_score = 0.901, mean_score = 0.795,


                                                    

fold 1 epoch 19, comp_score = 0.904, mean_score = 0.794,


                                                    

fold 1 epoch 20, comp_score = 0.900, mean_score = 0.797,


                                                    

fold 1 epoch 21, comp_score = 0.910, mean_score = 0.813,


                                                    

fold 1 epoch 22, comp_score = 0.898, mean_score = 0.802,


                                                    

fold 1 epoch 23, comp_score = 0.910, mean_score = 0.812,


                                                    

fold 1 epoch 24, comp_score = 0.898, mean_score = 0.793,


                                                    

fold 1 epoch 25, comp_score = 0.890, mean_score = 0.777,


                                                    

fold 1 epoch 26, comp_score = 0.905, mean_score = 0.798,


                                                    

fold 1 epoch 27, comp_score = 0.903, mean_score = 0.804,


                                                    

fold 1 epoch 28, comp_score = 0.904, mean_score = 0.797,


                                                    

fold 1 epoch 29, comp_score = 0.903, mean_score = 0.796,


                                                    

fold 1 epoch 30, comp_score = 0.901, mean_score = 0.795,


                                                    

fold 1 epoch 31, comp_score = 0.899, mean_score = 0.794,


                                                    

fold 1 epoch 32, comp_score = 0.900, mean_score = 0.791,


                                                    

fold 1 epoch 33, comp_score = 0.893, mean_score = 0.792,


                                                    

fold 1 epoch 34, comp_score = 0.905, mean_score = 0.802,


                                                    

fold 1 epoch 35, comp_score = 0.903, mean_score = 0.797,


                                                    

fold 1 epoch 36, comp_score = 0.906, mean_score = 0.798,


                                                    

fold 1 epoch 37, comp_score = 0.905, mean_score = 0.794,


                                                    

fold 1 epoch 38, comp_score = 0.906, mean_score = 0.805,


                                                    

fold 1 epoch 39, comp_score = 0.889, mean_score = 0.786,


                                                    

fold 1 epoch 40, comp_score = 0.904, mean_score = 0.806,


                                                    

fold 1 epoch 41, comp_score = 0.902, mean_score = 0.797,


                                                    

fold 1 epoch 42, comp_score = 0.891, mean_score = 0.787,


                                                    

fold 1 epoch 43, comp_score = 0.901, mean_score = 0.795,


                                                    

fold 1 epoch 44, comp_score = 0.900, mean_score = 0.793,


                                                    

fold 1 epoch 45, comp_score = 0.899, mean_score = 0.794,


                                                    

fold 1 epoch 46, comp_score = 0.896, mean_score = 0.787,


                                                    

fold 1 epoch 47, comp_score = 0.887, mean_score = 0.777,


                                                    

fold 1 epoch 48, comp_score = 0.897, mean_score = 0.786,


                                                    

fold 1 epoch 49, comp_score = 0.890, mean_score = 0.777,


                                                    

fold 1 epoch 50, comp_score = 0.891, mean_score = 0.775,


                                                    

fold 1 epoch 51, comp_score = 0.895, mean_score = 0.784,


                                                    

fold 1 epoch 52, comp_score = 0.898, mean_score = 0.783,


                                                    

fold 1 epoch 53, comp_score = 0.896, mean_score = 0.781,


                                                    

fold 1 epoch 54, comp_score = 0.894, mean_score = 0.781,


                                                    

fold 1 epoch 55, comp_score = 0.895, mean_score = 0.787,


                                                    

fold 1 epoch 56, comp_score = 0.894, mean_score = 0.783,


                                                    

fold 1 epoch 57, comp_score = 0.890, mean_score = 0.784,


                                                    

fold 1 epoch 58, comp_score = 0.885, mean_score = 0.775,


                                                    

fold 1 epoch 59, comp_score = 0.884, mean_score = 0.769,


                                                    

fold 1 epoch 60, comp_score = 0.890, mean_score = 0.778,


                                                    

fold 1 epoch 61, comp_score = 0.894, mean_score = 0.789,


                                                    

fold 1 epoch 62, comp_score = 0.889, mean_score = 0.786,


                                                    

fold 1 epoch 63, comp_score = 0.890, mean_score = 0.786,


                                                    

fold 1 epoch 64, comp_score = 0.891, mean_score = 0.787,


                                                    

fold 1 epoch 65, comp_score = 0.893, mean_score = 0.784,


                                                    

fold 1 epoch 66, comp_score = 0.895, mean_score = 0.788,


                                                    

fold 1 epoch 67, comp_score = 0.893, mean_score = 0.784,


                                                    

fold 1 epoch 68, comp_score = 0.893, mean_score = 0.785,


                                                    

fold 1 epoch 69, comp_score = 0.895, mean_score = 0.785,


                                                    

fold 1 epoch 70, comp_score = 0.897, mean_score = 0.788,


                                                    

fold 1 epoch 71, comp_score = 0.895, mean_score = 0.789,


                                                    

fold 1 epoch 72, comp_score = 0.895, mean_score = 0.788,


                                                    

fold 1 epoch 73, comp_score = 0.895, mean_score = 0.791,


                                                    

fold 1 epoch 74, comp_score = 0.898, mean_score = 0.791,


                                                    

fold 1 epoch 75, comp_score = 0.896, mean_score = 0.788,


                                                    

fold 1 epoch 76, comp_score = 0.895, mean_score = 0.788,


                                                    

fold 1 epoch 77, comp_score = 0.896, mean_score = 0.789,


                                                    

fold 1 epoch 78, comp_score = 0.895, mean_score = 0.789,


                                                    

fold 1 epoch 79, comp_score = 0.896, mean_score = 0.788,


                                                    

fold 1 epoch 80, comp_score = 0.895, mean_score = 0.789,


                                                    

fold 1 epoch 81, comp_score = 0.892, mean_score = 0.788,


                                                    

fold 1 epoch 82, comp_score = 0.893, mean_score = 0.786,


                                                    

fold 1 epoch 83, comp_score = 0.894, mean_score = 0.789,


                                                    

fold 1 epoch 84, comp_score = 0.891, mean_score = 0.789,


                                                    

fold 1 epoch 85, comp_score = 0.891, mean_score = 0.789,


                                                    

fold 1 epoch 86, comp_score = 0.892, mean_score = 0.790,


                                                    

fold 1 epoch 87, comp_score = 0.893, mean_score = 0.790,


                                                    

fold 1 epoch 88, comp_score = 0.893, mean_score = 0.791,


                                                    

fold 1 epoch 89, comp_score = 0.893, mean_score = 0.791,


                                                    

fold 2 epoch 0, comp_score = 0.900, mean_score = 0.795,


                                                    

fold 2 epoch 1, comp_score = 0.880, mean_score = 0.768,


                                                    

fold 2 epoch 2, comp_score = 0.883, mean_score = 0.778,


                                                    

fold 2 epoch 3, comp_score = 0.895, mean_score = 0.787,


                                                    

fold 2 epoch 4, comp_score = 0.903, mean_score = 0.803,


                                                    

fold 2 epoch 5, comp_score = 0.910, mean_score = 0.824,


                                                    

fold 2 epoch 6, comp_score = 0.909, mean_score = 0.819,


                                                    

fold 2 epoch 7, comp_score = 0.906, mean_score = 0.810,


                                                    

fold 2 epoch 8, comp_score = 0.904, mean_score = 0.805,


                                                    

fold 2 epoch 9, comp_score = 0.904, mean_score = 0.806,


                                                    

fold 2 epoch 10, comp_score = 0.904, mean_score = 0.802,


                                                    

fold 2 epoch 11, comp_score = 0.907, mean_score = 0.807,


                                                    

fold 2 epoch 12, comp_score = 0.910, mean_score = 0.811,


                                                    

fold 2 epoch 13, comp_score = 0.906, mean_score = 0.807,


                                                    

fold 2 epoch 14, comp_score = 0.906, mean_score = 0.810,


                                                    

fold 2 epoch 15, comp_score = 0.902, mean_score = 0.811,


                                                    

fold 2 epoch 16, comp_score = 0.907, mean_score = 0.811,


                                                    

fold 2 epoch 17, comp_score = 0.902, mean_score = 0.810,


                                                    

fold 2 epoch 18, comp_score = 0.904, mean_score = 0.809,


                                                    

fold 2 epoch 19, comp_score = 0.904, mean_score = 0.804,


                                                    

fold 2 epoch 20, comp_score = 0.905, mean_score = 0.813,


                                                    

fold 2 epoch 21, comp_score = 0.909, mean_score = 0.814,


                                                    

fold 2 epoch 22, comp_score = 0.902, mean_score = 0.805,


                                                    

fold 2 epoch 23, comp_score = 0.908, mean_score = 0.811,


                                                    

fold 2 epoch 24, comp_score = 0.911, mean_score = 0.813,


                                                    

fold 2 epoch 25, comp_score = 0.889, mean_score = 0.787,


                                                    

fold 2 epoch 26, comp_score = 0.912, mean_score = 0.820,


                                                    

fold 2 epoch 27, comp_score = 0.911, mean_score = 0.815,


                                                    

fold 2 epoch 28, comp_score = 0.909, mean_score = 0.812,


                                                    

fold 2 epoch 29, comp_score = 0.912, mean_score = 0.810,


                                                    

fold 2 epoch 30, comp_score = 0.911, mean_score = 0.816,


                                                    

fold 2 epoch 31, comp_score = 0.909, mean_score = 0.816,


                                                    

fold 2 epoch 32, comp_score = 0.909, mean_score = 0.811,


                                                    

fold 2 epoch 33, comp_score = 0.906, mean_score = 0.807,


                                                    

fold 2 epoch 34, comp_score = 0.909, mean_score = 0.812,


                                                    

fold 2 epoch 35, comp_score = 0.916, mean_score = 0.825,


                                                    

fold 2 epoch 36, comp_score = 0.911, mean_score = 0.822,


                                                    

fold 2 epoch 37, comp_score = 0.912, mean_score = 0.822,


                                                    

fold 2 epoch 38, comp_score = 0.914, mean_score = 0.821,


                                                    

fold 2 epoch 39, comp_score = 0.912, mean_score = 0.819,


                                                    

fold 2 epoch 40, comp_score = 0.913, mean_score = 0.820,


                                                    

fold 2 epoch 41, comp_score = 0.913, mean_score = 0.827,


                                                    

fold 2 epoch 42, comp_score = 0.914, mean_score = 0.824,


                                                    

fold 2 epoch 43, comp_score = 0.914, mean_score = 0.826,


                                                    

fold 2 epoch 44, comp_score = 0.914, mean_score = 0.825,


                                                    

fold 2 epoch 45, comp_score = 0.915, mean_score = 0.825,


                                                    

fold 2 epoch 46, comp_score = 0.915, mean_score = 0.827,


                                                    

fold 2 epoch 47, comp_score = 0.915, mean_score = 0.829,


                                                    

fold 2 epoch 48, comp_score = 0.914, mean_score = 0.829,


                                                    

fold 2 epoch 49, comp_score = 0.915, mean_score = 0.826,


                                                    

fold 2 epoch 50, comp_score = 0.914, mean_score = 0.826,


                                                    

fold 2 epoch 51, comp_score = 0.913, mean_score = 0.824,


                                                    

fold 2 epoch 52, comp_score = 0.913, mean_score = 0.826,


                                                    

fold 2 epoch 53, comp_score = 0.914, mean_score = 0.826,


                                                    

fold 2 epoch 54, comp_score = 0.911, mean_score = 0.825,


                                                    

fold 2 epoch 55, comp_score = 0.913, mean_score = 0.826,


                                                    

fold 2 epoch 56, comp_score = 0.912, mean_score = 0.824,


                                                    

fold 2 epoch 57, comp_score = 0.908, mean_score = 0.821,


                                                    

fold 2 epoch 58, comp_score = 0.909, mean_score = 0.821,


                                                    

fold 2 epoch 59, comp_score = 0.910, mean_score = 0.820,


                                                    

fold 2 epoch 60, comp_score = 0.909, mean_score = 0.818,


                                                    

fold 2 epoch 61, comp_score = 0.911, mean_score = 0.816,


                                                    

fold 2 epoch 62, comp_score = 0.904, mean_score = 0.813,


                                                    

fold 2 epoch 63, comp_score = 0.904, mean_score = 0.810,


                                                    

fold 2 epoch 64, comp_score = 0.910, mean_score = 0.811,


                                                    

fold 2 epoch 65, comp_score = 0.915, mean_score = 0.813,


                                                    

fold 2 epoch 66, comp_score = 0.915, mean_score = 0.808,


                                                    

fold 2 epoch 67, comp_score = 0.916, mean_score = 0.812,


                                                    

fold 2 epoch 68, comp_score = 0.912, mean_score = 0.805,


                                                    

fold 2 epoch 69, comp_score = 0.911, mean_score = 0.805,


                                                    

fold 2 epoch 70, comp_score = 0.913, mean_score = 0.798,


                                                    

fold 2 epoch 71, comp_score = 0.911, mean_score = 0.798,


                                                    

fold 2 epoch 72, comp_score = 0.910, mean_score = 0.798,


                                                    

fold 2 epoch 73, comp_score = 0.906, mean_score = 0.790,


                                                    

fold 2 epoch 74, comp_score = 0.907, mean_score = 0.789,


                                                    

fold 2 epoch 75, comp_score = 0.908, mean_score = 0.790,


                                                    

fold 2 epoch 76, comp_score = 0.908, mean_score = 0.792,


                                                    

fold 2 epoch 77, comp_score = 0.908, mean_score = 0.794,


                                                    

fold 2 epoch 78, comp_score = 0.908, mean_score = 0.797,


                                                    

fold 2 epoch 79, comp_score = 0.906, mean_score = 0.796,


                                                    

fold 2 epoch 80, comp_score = 0.908, mean_score = 0.799,


                                                    

fold 2 epoch 81, comp_score = 0.906, mean_score = 0.797,


                                                    

fold 2 epoch 82, comp_score = 0.908, mean_score = 0.798,


                                                    

fold 2 epoch 83, comp_score = 0.907, mean_score = 0.797,


                                                    

fold 2 epoch 84, comp_score = 0.904, mean_score = 0.793,


                                                    

fold 2 epoch 85, comp_score = 0.903, mean_score = 0.790,


                                                    

fold 2 epoch 86, comp_score = 0.904, mean_score = 0.792,


                                                    

fold 2 epoch 87, comp_score = 0.902, mean_score = 0.791,


                                                    

fold 2 epoch 88, comp_score = 0.901, mean_score = 0.790,


                                                    

fold 2 epoch 89, comp_score = 0.902, mean_score = 0.790,


                                                    

fold 3 epoch 0, comp_score = 0.905, mean_score = 0.779,


                                                    

fold 3 epoch 1, comp_score = 0.900, mean_score = 0.770,


                                                    

fold 3 epoch 2, comp_score = 0.906, mean_score = 0.810,


                                                    

fold 3 epoch 3, comp_score = 0.889, mean_score = 0.772,


                                                    

fold 3 epoch 4, comp_score = 0.893, mean_score = 0.792,


                                                    

fold 3 epoch 5, comp_score = 0.908, mean_score = 0.811,


                                                    

fold 3 epoch 6, comp_score = 0.907, mean_score = 0.811,


                                                    

fold 3 epoch 7, comp_score = 0.907, mean_score = 0.820,


                                                    

fold 3 epoch 8, comp_score = 0.920, mean_score = 0.831,


                                                    

fold 3 epoch 9, comp_score = 0.915, mean_score = 0.828,


                                                    

fold 3 epoch 10, comp_score = 0.908, mean_score = 0.815,


                                                    

fold 3 epoch 11, comp_score = 0.919, mean_score = 0.830,


                                                    

fold 3 epoch 12, comp_score = 0.912, mean_score = 0.826,


                                                    

fold 3 epoch 13, comp_score = 0.915, mean_score = 0.813,


                                                    

fold 3 epoch 14, comp_score = 0.915, mean_score = 0.824,


                                                    

fold 3 epoch 15, comp_score = 0.914, mean_score = 0.828,


                                                    

fold 3 epoch 16, comp_score = 0.916, mean_score = 0.822,


                                                    

fold 3 epoch 17, comp_score = 0.913, mean_score = 0.825,


                                                    

fold 3 epoch 18, comp_score = 0.914, mean_score = 0.822,


                                                    

fold 3 epoch 19, comp_score = 0.913, mean_score = 0.808,


                                                    

fold 3 epoch 20, comp_score = 0.915, mean_score = 0.796,


                                                    

fold 3 epoch 21, comp_score = 0.911, mean_score = 0.816,


                                                    

fold 3 epoch 22, comp_score = 0.908, mean_score = 0.768,


                                                    

fold 3 epoch 23, comp_score = 0.915, mean_score = 0.790,


                                                    

fold 3 epoch 24, comp_score = 0.917, mean_score = 0.799,


                                                    

fold 3 epoch 25, comp_score = 0.924, mean_score = 0.829,


                                                    

fold 3 epoch 26, comp_score = 0.918, mean_score = 0.790,


                                                    

fold 3 epoch 27, comp_score = 0.922, mean_score = 0.799,


                                                    

fold 3 epoch 28, comp_score = 0.910, mean_score = 0.812,


                                                    

fold 3 epoch 29, comp_score = 0.917, mean_score = 0.824,


                                                    

fold 3 epoch 30, comp_score = 0.908, mean_score = 0.803,


                                                    

fold 3 epoch 31, comp_score = 0.912, mean_score = 0.813,


                                                    

fold 3 epoch 32, comp_score = 0.921, mean_score = 0.831,


                                                    

fold 3 epoch 33, comp_score = 0.913, mean_score = 0.810,


                                                    

fold 3 epoch 34, comp_score = 0.907, mean_score = 0.809,


                                                    

fold 3 epoch 35, comp_score = 0.895, mean_score = 0.793,


                                                    

fold 3 epoch 36, comp_score = 0.923, mean_score = 0.827,


                                                    

fold 3 epoch 37, comp_score = 0.910, mean_score = 0.814,


                                                    

fold 3 epoch 38, comp_score = 0.916, mean_score = 0.825,


                                                    

fold 3 epoch 39, comp_score = 0.915, mean_score = 0.823,


                                                    

fold 3 epoch 40, comp_score = 0.921, mean_score = 0.830,


                                                    

fold 3 epoch 41, comp_score = 0.914, mean_score = 0.814,


                                                    

fold 3 epoch 42, comp_score = 0.916, mean_score = 0.820,


                                                    

fold 3 epoch 43, comp_score = 0.921, mean_score = 0.826,


                                                    

fold 3 epoch 44, comp_score = 0.921, mean_score = 0.831,


                                                    

fold 3 epoch 45, comp_score = 0.911, mean_score = 0.812,


                                                   

fold 3 epoch 46, comp_score = 0.921, mean_score = 0.822,


                                                    

fold 3 epoch 48, comp_score = 0.922, mean_score = 0.832,


                                                    

fold 3 epoch 49, comp_score = 0.923, mean_score = 0.834,


                                                    

fold 3 epoch 50, comp_score = 0.924, mean_score = 0.830,


                                                    

fold 3 epoch 51, comp_score = 0.924, mean_score = 0.833,


                                                    

fold 3 epoch 52, comp_score = 0.925, mean_score = 0.838,


                                                    

fold 3 epoch 53, comp_score = 0.921, mean_score = 0.834,


                                                    

fold 3 epoch 54, comp_score = 0.924, mean_score = 0.831,


                                                    

fold 3 epoch 55, comp_score = 0.917, mean_score = 0.817,


                                                    

fold 3 epoch 56, comp_score = 0.917, mean_score = 0.822,


                                                    

fold 3 epoch 57, comp_score = 0.915, mean_score = 0.828,


                                                    

fold 3 epoch 58, comp_score = 0.908, mean_score = 0.812,


                                                    

fold 3 epoch 59, comp_score = 0.907, mean_score = 0.818,


                                                    

fold 3 epoch 60, comp_score = 0.906, mean_score = 0.816,


                                                    

fold 3 epoch 61, comp_score = 0.907, mean_score = 0.815,


                                                    

fold 3 epoch 62, comp_score = 0.910, mean_score = 0.818,


                                                    

fold 3 epoch 63, comp_score = 0.908, mean_score = 0.815,


                                                    

fold 3 epoch 64, comp_score = 0.912, mean_score = 0.818,


                                                    

fold 3 epoch 65, comp_score = 0.909, mean_score = 0.810,


                                                    

fold 3 epoch 66, comp_score = 0.906, mean_score = 0.803,


                                                    

fold 3 epoch 67, comp_score = 0.912, mean_score = 0.807,


                                                    

fold 3 epoch 69, comp_score = 0.908, mean_score = 0.804,


                                                    

fold 3 epoch 70, comp_score = 0.908, mean_score = 0.805,


                                                    

fold 3 epoch 71, comp_score = 0.906, mean_score = 0.804,


                                                    

fold 3 epoch 72, comp_score = 0.905, mean_score = 0.799,


                                                    

fold 3 epoch 73, comp_score = 0.914, mean_score = 0.804,


                                                    

fold 3 epoch 74, comp_score = 0.915, mean_score = 0.805,


                                                    

fold 3 epoch 75, comp_score = 0.917, mean_score = 0.808,


                                                    

fold 3 epoch 76, comp_score = 0.918, mean_score = 0.811,


                                                    

fold 3 epoch 77, comp_score = 0.916, mean_score = 0.809,


                                                    

fold 3 epoch 78, comp_score = 0.918, mean_score = 0.811,


                                                    

fold 3 epoch 79, comp_score = 0.917, mean_score = 0.811,


                                                    

fold 3 epoch 80, comp_score = 0.918, mean_score = 0.809,


                                                    

fold 3 epoch 81, comp_score = 0.919, mean_score = 0.811,


                                                    

fold 3 epoch 82, comp_score = 0.919, mean_score = 0.813,


                                                    

fold 3 epoch 83, comp_score = 0.919, mean_score = 0.815,


                                                    

fold 3 epoch 84, comp_score = 0.919, mean_score = 0.814,


                                                    

fold 3 epoch 85, comp_score = 0.918, mean_score = 0.814,


                                                    

fold 3 epoch 86, comp_score = 0.919, mean_score = 0.814,


                                                    

fold 3 epoch 87, comp_score = 0.917, mean_score = 0.810,


                                                    

fold 3 epoch 88, comp_score = 0.917, mean_score = 0.809,


                                                    

fold 4 epoch 0, comp_score = 0.915, mean_score = 0.807,


                                                    

fold 4 epoch 1, comp_score = 0.918, mean_score = 0.815,


                                                    

fold 4 epoch 2, comp_score = 0.924, mean_score = 0.827,


                                                    

fold 4 epoch 3, comp_score = 0.914, mean_score = 0.822,


                                                    

fold 4 epoch 4, comp_score = 0.922, mean_score = 0.829,


                                                    

fold 4 epoch 5, comp_score = 0.927, mean_score = 0.836,


                                                    

fold 4 epoch 6, comp_score = 0.929, mean_score = 0.838,


                                                    

fold 4 epoch 7, comp_score = 0.924, mean_score = 0.831,


                                                    

fold 4 epoch 8, comp_score = 0.920, mean_score = 0.811,


                                                    

fold 4 epoch 9, comp_score = 0.915, mean_score = 0.817,


                                                    

fold 4 epoch 10, comp_score = 0.919, mean_score = 0.832,


                                                    

fold 4 epoch 11, comp_score = 0.924, mean_score = 0.832,


                                                    

fold 4 epoch 12, comp_score = 0.921, mean_score = 0.815,


                                                    

fold 4 epoch 13, comp_score = 0.881, mean_score = 0.763,


                                                    

fold 4 epoch 14, comp_score = 0.923, mean_score = 0.826,


                                                    

fold 4 epoch 15, comp_score = 0.920, mean_score = 0.811,


                                                    

fold 4 epoch 16, comp_score = 0.925, mean_score = 0.831,


                                                    

fold 4 epoch 17, comp_score = 0.886, mean_score = 0.772,


                                                    

fold 4 epoch 18, comp_score = 0.916, mean_score = 0.812,


                                                    

fold 4 epoch 19, comp_score = 0.916, mean_score = 0.811,


                                                    

fold 4 epoch 20, comp_score = 0.920, mean_score = 0.821,


                                                    

fold 4 epoch 21, comp_score = 0.913, mean_score = 0.810,


                                                    

fold 4 epoch 22, comp_score = 0.922, mean_score = 0.819,


                                                    

fold 4 epoch 23, comp_score = 0.925, mean_score = 0.818,


                                                    

fold 4 epoch 24, comp_score = 0.924, mean_score = 0.829,


                                                    

fold 4 epoch 25, comp_score = 0.923, mean_score = 0.825,


                                                    

fold 4 epoch 26, comp_score = 0.914, mean_score = 0.803,


                                                    

fold 4 epoch 27, comp_score = 0.910, mean_score = 0.808,


                                                    

fold 4 epoch 28, comp_score = 0.914, mean_score = 0.813,


                                                    

fold 4 epoch 29, comp_score = 0.915, mean_score = 0.812,


                                                    

fold 4 epoch 30, comp_score = 0.922, mean_score = 0.827,


                                                    

fold 4 epoch 31, comp_score = 0.918, mean_score = 0.822,


                                                    

fold 4 epoch 32, comp_score = 0.917, mean_score = 0.815,


                                                    

fold 4 epoch 33, comp_score = 0.913, mean_score = 0.816,


                                                    

fold 4 epoch 34, comp_score = 0.918, mean_score = 0.812,


                                                    

fold 4 epoch 35, comp_score = 0.902, mean_score = 0.788,


                                                    

fold 4 epoch 36, comp_score = 0.915, mean_score = 0.820,


                                                    

fold 4 epoch 39, comp_score = 0.908, mean_score = 0.804,


                                                    

fold 4 epoch 40, comp_score = 0.911, mean_score = 0.813,


                                                    

fold 4 epoch 41, comp_score = 0.911, mean_score = 0.815,


                                                    

fold 4 epoch 42, comp_score = 0.909, mean_score = 0.812,


                                                    

fold 4 epoch 43, comp_score = 0.911, mean_score = 0.817,


                                                    

fold 4 epoch 44, comp_score = 0.907, mean_score = 0.810,


                                                    

fold 4 epoch 45, comp_score = 0.900, mean_score = 0.794,


                                                    

fold 4 epoch 46, comp_score = 0.898, mean_score = 0.802,


                                                    

fold 4 epoch 47, comp_score = 0.910, mean_score = 0.812,


                                                    

fold 4 epoch 48, comp_score = 0.904, mean_score = 0.808,


                                                    

fold 4 epoch 49, comp_score = 0.909, mean_score = 0.816,


                                                    

fold 4 epoch 50, comp_score = 0.903, mean_score = 0.807,


                                                    

fold 4 epoch 51, comp_score = 0.900, mean_score = 0.803,


                                                    

fold 4 epoch 52, comp_score = 0.903, mean_score = 0.811,


                                                    

fold 4 epoch 53, comp_score = 0.898, mean_score = 0.807,


                                                    

fold 4 epoch 54, comp_score = 0.903, mean_score = 0.811,


                                                    

fold 4 epoch 55, comp_score = 0.901, mean_score = 0.808,


                                                    

fold 4 epoch 56, comp_score = 0.898, mean_score = 0.807,


                                                    

fold 4 epoch 57, comp_score = 0.897, mean_score = 0.804,


                                                    

fold 4 epoch 58, comp_score = 0.899, mean_score = 0.802,


                                                    

fold 4 epoch 59, comp_score = 0.904, mean_score = 0.809,


                                                    

fold 4 epoch 60, comp_score = 0.905, mean_score = 0.808,


                                                    

fold 4 epoch 61, comp_score = 0.908, mean_score = 0.808,


                                                    

fold 4 epoch 62, comp_score = 0.907, mean_score = 0.811,


                                                    

fold 4 epoch 63, comp_score = 0.913, mean_score = 0.811,


                                                    

fold 4 epoch 64, comp_score = 0.911, mean_score = 0.811,


                                                    

fold 4 epoch 65, comp_score = 0.906, mean_score = 0.800,


                                                    

fold 4 epoch 66, comp_score = 0.910, mean_score = 0.804,


                                                    

fold 4 epoch 67, comp_score = 0.905, mean_score = 0.802,


                                                    

fold 4 epoch 68, comp_score = 0.905, mean_score = 0.798,


                                                    

fold 4 epoch 69, comp_score = 0.910, mean_score = 0.805,


                                                    

fold 4 epoch 70, comp_score = 0.909, mean_score = 0.797,


                                                    

fold 4 epoch 71, comp_score = 0.905, mean_score = 0.793,


                                                    

fold 4 epoch 72, comp_score = 0.906, mean_score = 0.792,


                                                    

fold 4 epoch 73, comp_score = 0.907, mean_score = 0.790,


                                                    

fold 4 epoch 74, comp_score = 0.907, mean_score = 0.789,


                                                    

fold 4 epoch 75, comp_score = 0.905, mean_score = 0.788,


                                                    

fold 4 epoch 78, comp_score = 0.908, mean_score = 0.801,


                                                    

fold 4 epoch 79, comp_score = 0.905, mean_score = 0.802,


                                                    

fold 4 epoch 80, comp_score = 0.904, mean_score = 0.800,


                                                    

fold 4 epoch 81, comp_score = 0.903, mean_score = 0.801,


                                                    

fold 4 epoch 82, comp_score = 0.908, mean_score = 0.804,


                                                    

fold 4 epoch 83, comp_score = 0.909, mean_score = 0.809,


                                                    

fold 4 epoch 84, comp_score = 0.910, mean_score = 0.810,


                                                    

fold 4 epoch 85, comp_score = 0.912, mean_score = 0.812,


                                                    

fold 4 epoch 86, comp_score = 0.912, mean_score = 0.810,


                                                    

fold 4 epoch 87, comp_score = 0.911, mean_score = 0.809,


                                                    

fold 4 epoch 88, comp_score = 0.912, mean_score = 0.808,


                                                    

fold 4 epoch 89, comp_score = 0.911, mean_score = 0.807,
comp_score = 0.9217602880169198, mean_score = 0.8337294808782281,


# 📊 Evaluate on Validation Dataset

# 🚀 Predict and Submit (only tile:xla predictions)

In [8]:
# 📊 Predict on Test Dataset (tile:xla)
# In this section, we use the trained model to make predictions on the test dataset ('tile:xla').

# Create a TileDataset for the 'tile:xla' test dataset
dataset = TileDataset(tile_xla["test"])

# List to store model predictions for each sample in the test dataset
tile_xla_predictions = [[] for i in range(len(dataset))]

# Iterate through each fold (previously trained models)
for fold in range(5):
    # Load the trained model weights for the current fold
    model.load_state_dict(torch.load(f'/kaggle/working/best_model_{fold}.pth'))
    model.eval()  # 🕵️ Set the model to evaluation mode
    pbar = tqdm(range(len(dataset)))  # Progress bar for test data prediction
    
    for i in pbar:
        cfg_ft, nd_ft, nd_op, ind, target = dataset[i]
        cfg_ft, nd_ft, nd_op, ind, target = cfg_ft.to(device), nd_ft.to(device), nd_op.to(device), ind.to(device), target.to(device)

        out = model(cfg_ft, nd_ft, nd_op, ind)
        tile_xla_predictions[i].append(out.cpu().detach().numpy())

# Aggregate predictions by taking the mean and selecting the top 5
tile_xla_predictions = [np.argsort(np.mean(pred, axis=0))[:5] for pred in tile_xla_predictions]

# The 'tile_xla_predictions' now contains the top 5 predicted results for each sample in the 'tile:xla' test dataset.
tile_xla_predictions


100%|██████████| 844/844 [00:04<00:00, 202.04it/s]
100%|██████████| 844/844 [00:03<00:00, 227.81it/s]
100%|██████████| 844/844 [00:03<00:00, 219.03it/s]
100%|██████████| 844/844 [00:03<00:00, 230.62it/s]
100%|██████████| 844/844 [00:03<00:00, 226.66it/s]


[array([746, 554, 127, 408, 731]),
 array([ 415, 5420, 7006, 6480, 5213]),
 array([1344,  161,  206, 1409,  935]),
 array([158, 210, 212,  79, 120]),
 array([222, 164, 124,  31, 139]),
 array([ 611,  565, 1111,  719,  912]),
 array([184, 216, 146, 158, 221]),
 array([6768, 3649,  219,  323, 4885]),
 array([ 12,  49,  16, 115,  27]),
 array([ 74, 123,   2, 115,  61]),
 array([ 84, 519, 691, 121, 113]),
 array([ 94, 196, 241, 403, 363]),
 array([111,  30,  28,  78, 107]),
 array([1558, 7899, 1335, 8265, 7565]),
 array([59, 62, 63, 92, 99]),
 array([124, 132, 129,   7,   8]),
 array([  24,  262, 1495,  979,  521]),
 array([5576, 3856, 5896,  302, 2465]),
 array([229, 180, 437, 384, 339]),
 array([167,  17, 430, 360, 143]),
 array([69, 13,  0,  2, 73]),
 array([ 26, 222, 491, 383, 304]),
 array([5674, 6096,  486, 6254,  886]),
 array([20, 22,  9, 10,  2]),
 array([2199, 8680, 4033, 5925, 1843]),
 array([295, 327, 509, 677, 329]),
 array([  2, 208, 138, 227,  77]),
 array([2513, 1970,  675,

In [9]:
# 📊 Generate and Save Submission File
# In this section, we generate a submission file based on the model predictions and save it.

# Read the sample submission file
sub = pd.read_csv('/kaggle/input/predict-ai-model-runtime/sample_submission.csv')

# Iterate through the test file names and update the submission file with top predictions
for i, filename in enumerate(tile_xla["test"]['file'].values):
    id = 'tile:xla:' + filename[:-4]  # Construct the ID for the submission
    sub.loc[sub.ID == id, 'TopConfigs'] = ';'.join(tile_xla_predictions[i].astype(str))

# Save the updated submission file as 'submission.csv' without the index
sub = sub[sub.ID.str.contains('tile')].reset_index(drop=True)
sub.to_csv('inference_tile_xla.csv', index=False)

# Display the updated submission file
sub


Unnamed: 0,ID,TopConfigs
0,tile:xla:d6f5f54247bd1e58a10b9e7062c636ab,0;22;21;20;19
1,tile:xla:e3a655daa38e34ec240df959b650ac16,124;802;1301;711;385
2,tile:xla:f8c2c1a1098b2a361c26df668b286c87,27;129;12;202;2
3,tile:xla:4dd1716853ed46ee4e7d09ede1732de8,903;1015;3939;7320;4945
4,tile:xla:d0a69155b6340748c36724e4bfc34be3,641;264;252;640;248
...,...,...
889,layout:nlp:random:60880ed76de53f4d7a1b960b24f2...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
890,layout:nlp:random:23559853d9702baaaacbb0c83fd3...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
891,layout:nlp:random:f6c146fc5cf10be4f3accbaca989...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
892,layout:nlp:random:32531d07a084b319dce484f53a4c...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
