In [1]:
import torch
import wandb
import numpy as np
from torch import nn
from torch import optim
from torch_geometric.nn import GCNConv, GATConv
from torch.utils.data import Dataset
from tqdm import tqdm
from sklearn.model_selection import KFold

from gfos.data.utils import load_layout
from gfos.utils.scheduler import CosineAnnealingWarmupRestarts
from gfos.metrics import metric_for_layout_collections
from gfos.loss import MultiElementRankLoss

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


## Configs

In [2]:
# Folder passed to `load_layout` in the training cell (absolute, machine-specific path).
LAYOUT_DIR = r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout"

# Hyper-parameters of the run. The whole mapping is also handed to
# `wandb.init(config=...)` so every value is recorded with the run.
configs = {
    # NOTE(review): "conv_layer" is not read anywhere in the visible cells.
    "conv_layer": "GAT",
    # Optimisation schedule.
    "num_epochs": 1000,
    "learning_rate": 5e-4,
    "weight_decay": 1e-6,
    "min_lr": 1e-7,
    "warmup_ratio": 0.05,
    # Upper bound on the number of configs sampled per graph by LayoutDataset.
    "max_configs": 100,
    # Model sizes forwarded to SimpleModel.
    "graph_hidden": [16, 32, 16, 48],
    "graph_out": 64,
    "num_encoder": 1,
    "num_feedforward": 512,
    "nhead": 1,
    # Arguments forwarded to MultiElementRankLoss.
    "loss_margin": 0.1,
    "loss_num_permutations": 50,
}

# Weights & Biases run settings.
WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = "gcn_layout"
TAGS = ["train", "layout"]


## Dataset

In [3]:
class LayoutDataset(Dataset):
    """Layout graphs stored as `.npz` archives, one archive per item.

    An item already bundles up to `max_configs` randomly chosen configs of
    one graph, so it plays the role of a batch and is meant to be indexed
    directly rather than wrapped in a DataLoader.
    """

    def __init__(self, files: list[str], max_configs: int = 1000):
        self.max_configs = max_configs
        self.files = files
        # Every archive is opened here, once, and kept for later indexing.
        self.npzs = list(map(np.load, self.files))
        # Number of configs (runtime entries) available in each archive.
        self.num_records = [
            len(archive["config_runtime"]) for archive in self.npzs
        ]

    def __len__(self):
        """Number of graphs, i.e. number of archives."""
        return len(self.npzs)

    def __getitem__(self, idx):
        """Return the tensors of graph `idx` with a random subset of configs.

        Returns:
            Tuple `(node_feat, node_opcode, edge_index, node_config_feat,
            node_config_ids, target)`. `node_config_feat` and `target` are
            restricted to the same randomly drawn config indices; `target`
            is the runtime standardised over all configs of the graph.
        """
        archive = self.npzs[idx]

        def as_tensor(key, dtype):
            return torch.tensor(archive[key], dtype=dtype)

        # Standardise over the full set of configs *before* sub-sampling.
        runtime = as_tensor("config_runtime", torch.float)
        runtime = (runtime - runtime.mean()) / (runtime.std() + 1e-7)

        # Random subset of at most `max_configs` config indices.
        keep = torch.randperm(runtime.size(0))[: self.max_configs]
        target = runtime[keep]

        node_feat = as_tensor("node_feat", torch.float)
        node_opcode = as_tensor("node_opcode", torch.long)

        # Stored edge array has its first two axes swapped before use.
        edges = np.swapaxes(archive["edge_index"], 0, 1)
        edge_index = torch.tensor(edges, dtype=torch.long)

        node_config_feat = as_tensor("node_config_feat", torch.float)[keep]
        node_config_ids = as_tensor("node_config_ids", torch.long)

        return (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            target,
        )

## Model

In [4]:
# def transform_node_positional_embeddings(
#     embeddings_output: torch.Tensor,
#     node_config_ids: torch.Tensor,
#     num_nodes: int,
# ) -> torch.Tensor:
#     num_configs, _, dim = embeddings_output.shape
#     idxs = node_config_ids.unsqueeze(0).repeat(num_configs, 1) # [c, nc]
#     zeros = torch.zeros(
#         num_configs,
#         num_nodes,
#         dim,
#         device=embeddings_output.device,
#         dtype=embeddings_output.dtype,
#     )
#     idxs = idxs.unsqueeze(-1).repeat(1, 1, dim) # [c, nc, dim]
#     zeros.scatter_reduce_(1, idxs, embeddings_output, reduce="sum")
#     return zeros

In [5]:
class SimpleModel(torch.nn.Module):
    """Produces one score per config for a single layout graph.

    Stages, in order: op-code embedding concatenated with the raw node
    features, a stack of graph convolutions, selection of the nodes named
    by `node_config_ids`, concatenation with projected per-config node
    features, a transformer encoder, and a small MLP head.

    Args:
        hidden_channels: non-empty sequence of hidden widths of the conv stack.
        graph_out: output width of the last graph convolution.
        num_encoder: number of transformer encoder layers.
        num_feedforward: feed-forward width inside each encoder layer.
        nhead: number of attention heads.
    """

    def __init__(
        self,
        hidden_channels,
        graph_out,
        num_encoder=1,
        num_feedforward=256,
        nhead=1,
    ):
        super().__init__()

        config_dim = 64  # width of the projected config features
        op_embedding_dim = 4  # width of the learned op-code embedding
        merged_node_dim = graph_out + config_dim

        # Embedding table sized for 120 distinct op-codes.
        self.embedding = torch.nn.Embedding(120, op_embedding_dim)

        assert len(hidden_channels) > 0

        # The first conv consumes 140 feature columns plus the op embedding.
        in_channels = op_embedding_dim + 140

        # Graph stack: one GCNConv into the first hidden width, then a
        # GATConv between each pair of consecutive widths, ending at graph_out.
        widths = [*hidden_channels, graph_out]
        graph_layers = [GCNConv(in_channels, widths[0])]
        for src_dim, dst_dim in zip(widths, widths[1:]):
            graph_layers.append(GATConv(src_dim, dst_dim))
        self.convs = torch.nn.ModuleList(graph_layers)

        # Transformer encoder over the merged (graph + config) node features.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=merged_node_dim,
                dim_feedforward=num_feedforward,
                nhead=nhead,
                batch_first=True,
            ),
            num_layers=num_encoder,
        )

        self.layernorm = nn.LayerNorm(merged_node_dim)

        # Projects the 18 per-node config features up to `config_dim`.
        self.config_prj = nn.Sequential(
            nn.Linear(18, config_dim),
            nn.LayerNorm(config_dim),
        )

        # MLP head mapping the pooled representation to a single scalar.
        self.dense = torch.nn.Sequential(
            nn.Linear(merged_node_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Score every config of one graph.

        Shapes noted below assume `node_config_feat` is (C, NC, 18) and
        `node_config_ids` holds NC node indices - TODO confirm against the
        data loader.

        Returns:
            Flattened tensor with one value per config.
        """
        num_configs = node_config_feat.size(0)

        # Node inputs: raw features followed by the op-code embedding.
        hidden = torch.cat((node_feat, self.embedding(node_opcode)), dim=1)

        for graph_layer in self.convs:
            hidden = torch.relu(graph_layer(hidden, edge_index))

        # Keep only the rows of the nodes listed in node_config_ids.
        hidden = hidden[node_config_ids]

        projected_cfg = self.config_prj(node_config_feat)

        # Tile the graph features once per config, then append the config
        # features along the last axis.
        tiled = hidden.repeat(num_configs, 1, 1)
        merged = torch.cat((tiled, projected_cfg), dim=-1)

        # TODO: could be removed?
        merged = self.layernorm(merged)

        # Only the encoder output at the last sequence position is kept.
        pooled = self.encoder(merged)[:, -1, :]

        return self.dense(pooled).flatten()

## Training

In [6]:
# Training setup: data listing, hyper-parameters, model, loss, optimizer,
# LR scheduler and the wandb run.
# NOTE: despite the commented-out KFold scaffolding further down, no
# cross-validation is performed here - a single train/valid split is used.

# Indexed below with "train" and "valid"; each entry is passed to
# LayoutDataset as its `files` list (presumably lists of .npz paths - confirm
# against gfos.data.utils.load_layout).
layout_random = load_layout(
    LAYOUT_DIR,
    compile_type="random",
)
# TODO: use dataloader concating batch_size samples
# NOTE(review): `train_val_dataset` is not referenced again in the visible
# cells; constructing it still opens every train + valid archive (with the
# default max_configs=1000).
train_val_dataset = LayoutDataset(
    layout_random["train"] + layout_random["valid"]
)

# Unpack the hyper-parameters from the config cell into plain names.
num_epochs = configs["num_epochs"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
warmup_ratio = configs["warmup_ratio"]
max_configs = configs["max_configs"]
graph_hidden = configs["graph_hidden"]
graph_out = configs["graph_out"]
num_encoder = configs["num_encoder"]
num_feedforward = configs["num_feedforward"]
nhead = configs["nhead"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]

model = SimpleModel(
    hidden_channels=graph_hidden,
    graph_out=graph_out,
    num_encoder=num_encoder,
    num_feedforward=num_feedforward,
    nhead=nhead,
).to(device)

# Create a K-Fold cross-validator with 5 splits
# kfold = KFold(
#     n_splits=5, shuffle=True, random_state=42
# )

# Iterate through each fold
# for fold, (tr_idx, va_idx) in enumerate(kfold.split(df)):

criterion = MultiElementRankLoss(margin=margin, number_permutations=number_permutations)
# The training loop takes one optimizer/scheduler step per training graph, so
# the schedule length is (#training graphs) * (#epochs).
num_steps = (
    len(layout_random["train"]) * num_epochs
)
# The first `warmup_ratio` fraction of those steps is reserved for warm-up.
warmup_steps = int(num_steps * warmup_ratio)

optimizer = optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
# A single cycle spanning the whole run (first_cycle_steps == num_steps).
scheduler = CosineAnnealingWarmupRestarts(
    optimizer=optimizer,
    first_cycle_steps=num_steps,
    min_lr=min_lr,
    max_lr=learning_rate,
    warmup_steps=warmup_steps,
)

# Start the wandb run; the full `configs` mapping is stored as the run config.
run = wandb.init(
    project=WANDB_PROJECT,
    dir=WANDB_DIR,
    name=WANDB_RUN_NAME,
    config=configs,
    tags=TAGS,
)
# Ask wandb to track the model with log="all".
run.watch(model, log="all")

# Training / validation loop.
#
# Each epoch visits every training graph once in a fresh random order (one
# optimizer step per graph), then scores every validation graph with
# `metric_for_layout_collections`. A checkpoint is written whenever the mean
# validation score improves.
best_score_mean = -1

# Max gradient norm handed to clip_grad_norm_ on every training step.
max_grad_norm = 1e-2

# LayoutDataset opens every archive in its constructor, so both datasets are
# built once here instead of once per epoch. The per-epoch shuffle is applied
# to the indices, which visits the same files in the same order as shuffling
# the file list and rebuilding the dataset would.
train_dataset = LayoutDataset(layout_random["train"], max_configs=max_configs)
val_dataset = LayoutDataset(layout_random["valid"], max_configs=max_configs)

for epoch in range(num_epochs):
    permutation = np.random.permutation(len(train_dataset))

    # ---- training ----
    model.train()
    pbar = tqdm(permutation, leave=False)

    for i, file_idx in enumerate(pbar):
        (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            target,
        ) = (tensor.to(device) for tensor in train_dataset[file_idx])

        optimizer.zero_grad()

        try:
            out = model(
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat,
                node_config_ids,
            )
        except torch.cuda.OutOfMemoryError:
            # Best effort: skip graphs whose forward pass does not fit in memory.
            print(f"Cuda out of memory at step {i} with node_feat.shape = {node_feat.shape}")
            torch.cuda.empty_cache()
            continue

        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Read the learning rate that this update actually uses, then apply
        # the update *before* advancing the schedule. Stepping the scheduler
        # first skips the first scheduled value and triggers a PyTorch warning.
        current_lr = optimizer.param_groups[0]["lr"]
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        wandb.log(
            {
                "epoch": epoch,
                "train/lr": current_lr,
                "train/loss": loss_value,
            }
        )

        pbar.set_description(f"epoch: {epoch} loss: {loss_value:.2f}")

    pbar.close()

    # ---- validation ----
    model.eval()
    layout_xla_scores = []
    pbar = tqdm(range(len(val_dataset)), leave=False)

    # Scoring needs no gradients; disabling autograd avoids keeping
    # activations alive for every validation graph.
    with torch.no_grad():
        for i in pbar:
            *features, target = val_dataset[i]

            out: torch.Tensor = model(
                *(tensor.to(device) for tensor in features)
            )

            score = metric_for_layout_collections(out.cpu().numpy(), target.numpy())

            # `score == np.nan` is always False (NaN never compares equal),
            # so the check has to go through np.isnan.
            if np.isnan(score):
                print(f"score is nan at step {i} with file {val_dataset.files[i]}")
                print(out)

            layout_xla_scores.append(score)

    pbar.close()

    # NOTE: a single NaN score makes both aggregates NaN, in which case the
    # comparison below is False and no checkpoint is written for this epoch.
    score_mean = np.mean(layout_xla_scores)
    score_max = np.max(layout_xla_scores)

    wandb.log(
        {
            "val/kendalltau": score_mean,
        }
    )

    print(
        f"epoch {epoch}, max_score = {score_max:.4f}, mean_score = {score_mean:.4f},"
    )

    # Update best scores and save the model if the mean score improves
    if score_mean > best_score_mean:
        best_score_mean = score_mean
        best_score_max = score_max
        print(f"Best score updated: {best_score_mean:.4f}")
        torch.save(model.state_dict(), f"{epoch}_{score_mean:.4f}.pth")

run.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33medenn0[0m. Use [1m`wandb login --relogin`[0m to force relogin


                                                                      

epoch 0, max_score = 0.1653, mean_score = -0.0212,
Best score updated: -0.0212


                                                                      

epoch 1, max_score = 0.1619, mean_score = -0.0008,
Best score updated: -0.0008


                                                                      

epoch 2, max_score = 0.1273, mean_score = -0.0301,


                                                                      

epoch 3, max_score = 0.1192, mean_score = -0.0051,


                                                                      

epoch 4, max_score = 0.1143, mean_score = -0.0104,


                                                                      

epoch 5, max_score = 0.1224, mean_score = -0.0066,


                                                                      

epoch 6, max_score = 0.2251, mean_score = -0.0167,


                                                                      

epoch 7, max_score = 0.1224, mean_score = -0.0378,


                                                                      

epoch 8, max_score = 0.1438, mean_score = -0.0075,


                                                                      

epoch 9, max_score = 0.2515, mean_score = -0.0193,


                                                                       

epoch 10, max_score = 0.1240, mean_score = -0.0202,


                                                                       

epoch 11, max_score = 0.0792, mean_score = -0.0303,


                                                                       

epoch 12, max_score = 0.1200, mean_score = -0.0190,


                                                                       

epoch 13, max_score = 0.1647, mean_score = -0.0289,


                                                                       

epoch 14, max_score = 0.1788, mean_score = -0.0087,


                                                                       

epoch 15, max_score = 0.2457, mean_score = -0.0060,


                                                                       

epoch 16, max_score = 0.1892, mean_score = 0.0090,
Best score updated: 0.0090


                                                                       

epoch 17, max_score = 0.2133, mean_score = -0.0184,


                                                                       

epoch 18, max_score = 0.1566, mean_score = -0.0022,


                                                                       

epoch 19, max_score = 0.2150, mean_score = 0.0121,
Best score updated: 0.0121


                                                                       

epoch 20, max_score = 0.2479, mean_score = 0.0053,


                                                                       

epoch 21, max_score = 0.1891, mean_score = 0.0040,


                                                                       

epoch 22, max_score = 0.1764, mean_score = 0.0315,
Best score updated: 0.0315


                                                                       

epoch 23, max_score = 0.1370, mean_score = 0.0073,


                                                                       

epoch 24, max_score = 0.1879, mean_score = 0.0260,


                                                                       

epoch 25, max_score = 0.1802, mean_score = 0.0007,


                                                                       

epoch 26, max_score = 0.1414, mean_score = -0.0424,


                                                                       

epoch 27, max_score = 0.1838, mean_score = 0.0138,


                                                                       

epoch 28, max_score = 0.2331, mean_score = -0.0104,


                                                                       

epoch 29, max_score = 0.1640, mean_score = -0.0037,


                                                                       

epoch 30, max_score = 0.1180, mean_score = -0.0057,


                                                                       

epoch 31, max_score = 0.1095, mean_score = -0.0195,


                                                                       

epoch 32, max_score = 0.1984, mean_score = 0.0184,


                                                                       

epoch 33, max_score = 0.2283, mean_score = -0.0079,


                                                                       

epoch 34, max_score = 0.2040, mean_score = -0.0199,


                                                                       

epoch 35, max_score = 0.1265, mean_score = 0.0019,


                                                                       

epoch 36, max_score = 0.1910, mean_score = 0.0236,


                                                                       

epoch 37, max_score = 0.1966, mean_score = -0.0060,


                                                                       

epoch 38, max_score = 0.1333, mean_score = -0.0152,


                                                                       

epoch 39, max_score = 0.2546, mean_score = 0.0091,


                                                                       

epoch 40, max_score = 0.1412, mean_score = -0.0039,


                                                                       

epoch 41, max_score = 0.0945, mean_score = -0.0220,


                                                                       

epoch 42, max_score = 0.1324, mean_score = -0.0006,


                                                                       

epoch 43, max_score = 0.2513, mean_score = 0.0127,


                                                                       

epoch 44, max_score = 0.1584, mean_score = 0.0195,


                                                                       

epoch 45, max_score = 0.1299, mean_score = 0.0121,


                                                                       

epoch 46, max_score = 0.1774, mean_score = -0.0153,


                                                                       

epoch 47, max_score = 0.1495, mean_score = -0.0042,


                                                                       

epoch 48, max_score = 0.2840, mean_score = -0.0003,


                                                                       

epoch 49, max_score = 0.1590, mean_score = -0.0048,


                                                                       

epoch 50, max_score = 0.0828, mean_score = -0.0493,


                                                                       

epoch 51, max_score = 0.2085, mean_score = -0.0061,


                                                                       

epoch 52, max_score = 0.1487, mean_score = -0.0054,


                                                                       

epoch 53, max_score = 0.1228, mean_score = 0.0060,


                                                                       

epoch 54, max_score = 0.1653, mean_score = -0.0152,


                                                                       

epoch 55, max_score = 0.1697, mean_score = 0.0006,


                                                                       

epoch 56, max_score = 0.1285, mean_score = -0.0232,


                                                                       

epoch 57, max_score = 0.1564, mean_score = -0.0181,


                                                                       

epoch 58, max_score = 0.1285, mean_score = 0.0049,


                                                                       

epoch 59, max_score = 0.2792, mean_score = -0.0030,


                                                                       

epoch 60, max_score = 0.0994, mean_score = -0.0301,


                                                                       

epoch 61, max_score = 0.1529, mean_score = -0.0037,


                                                                       

epoch 62, max_score = 0.2964, mean_score = 0.0098,


                                                                       

epoch 63, max_score = 0.1485, mean_score = 0.0159,


                                                                       

epoch 64, max_score = 0.1774, mean_score = 0.0151,


                                                                       

epoch 65, max_score = 0.2096, mean_score = 0.0381,
Best score updated: 0.0381


                                                                       

epoch 66, max_score = 0.1978, mean_score = 0.0089,


                                                                       

epoch 67, max_score = 0.1297, mean_score = -0.0095,


                                                                       

epoch 68, max_score = 0.2384, mean_score = 0.0035,


                                                                       

epoch 69, max_score = 0.1810, mean_score = -0.0124,


                                                                       

epoch 70, max_score = 0.1309, mean_score = 0.0159,


                                                                       

epoch 71, max_score = 0.1564, mean_score = 0.0009,


                                                                       

epoch 72, max_score = 0.1693, mean_score = -0.0110,


                                                                       

epoch 73, max_score = 0.1507, mean_score = -0.0200,


                                                                       

epoch 74, max_score = 0.1034, mean_score = -0.0172,


                                                                       

epoch 75, max_score = 0.2335, mean_score = 0.0072,


                                                                       

epoch 76, max_score = 0.1315, mean_score = -0.0088,


                                                                       

epoch 77, max_score = 0.2089, mean_score = -0.0184,


                                                                       

epoch 78, max_score = 0.1459, mean_score = -0.0229,


                                                                       

epoch 79, max_score = 0.0954, mean_score = -0.0164,


                                                                       

epoch 80, max_score = 0.2198, mean_score = -0.0029,


                                                                       

epoch 81, max_score = 0.1592, mean_score = -0.0006,


                                                                       

epoch 82, max_score = 0.2206, mean_score = 0.0070,


                                                                       

epoch 83, max_score = 0.1895, mean_score = 0.0023,


                                                                       

epoch 84, max_score = 0.1818, mean_score = -0.0325,


                                                                       

epoch 85, max_score = 0.1899, mean_score = -0.0383,


                                                                       

epoch 86, max_score = 0.2239, mean_score = 0.0019,


                                                                       

epoch 87, max_score = 0.1996, mean_score = -0.0141,


                                                                       

epoch 88, max_score = 0.1446, mean_score = 0.0053,


                                                                       

epoch 89, max_score = 0.1640, mean_score = -0.0209,


                                                                       

epoch 90, max_score = 0.1867, mean_score = 0.0183,


                                                                       

epoch 91, max_score = 0.0836, mean_score = -0.0184,


                                                                       

epoch 92, max_score = 0.2545, mean_score = 0.0155,


                                                                       

epoch 93, max_score = 0.0743, mean_score = -0.0263,


                                                                       

epoch 94, max_score = 0.1562, mean_score = -0.0228,


                                                                       

epoch 95, max_score = 0.1366, mean_score = -0.0233,


                                                                       

epoch 96, max_score = 0.2461, mean_score = 0.0210,


                                                                       

epoch 97, max_score = 0.1119, mean_score = -0.0261,


                                                                       

epoch 98, max_score = 0.1515, mean_score = 0.0120,


                                                                       

epoch 99, max_score = 0.1556, mean_score = -0.0372,


                                                                        

epoch 100, max_score = 0.1212, mean_score = -0.0253,


                                                                        

epoch 101, max_score = 0.1006, mean_score = -0.0398,


                                                                        

epoch 102, max_score = 0.1160, mean_score = -0.0181,


                                                                        

epoch 103, max_score = 0.1608, mean_score = 0.0039,


                                                                        

epoch 104, max_score = 0.1935, mean_score = -0.0139,


                                                                        

epoch 105, max_score = 0.2036, mean_score = 0.0226,


                                                                        

epoch 106, max_score = 0.1685, mean_score = -0.0103,


                                                                        

epoch 107, max_score = 0.1677, mean_score = 0.0064,


                                                                        

epoch 108, max_score = 0.2299, mean_score = -0.0151,


                                                                        

epoch 109, max_score = 0.1915, mean_score = 0.0166,


                                                                        

epoch 110, max_score = 0.1895, mean_score = -0.0052,


                                                                        

epoch 111, max_score = 0.1348, mean_score = -0.0025,


                                                                        

epoch 112, max_score = 0.1714, mean_score = -0.0112,


                                                                          

epoch 113, max_score = 0.2509, mean_score = -0.0114,


                                                                        

epoch 114, max_score = 0.1851, mean_score = 0.0165,


                                                                        

epoch 115, max_score = 0.1903, mean_score = 0.0063,


                                                                        

epoch 116, max_score = 0.2114, mean_score = 0.0057,


                                                                        

epoch 117, max_score = 0.1127, mean_score = -0.0118,


                                                                        

epoch 118, max_score = 0.1709, mean_score = -0.0143,


                                                                        

epoch 119, max_score = 0.1935, mean_score = -0.0103,


                                                                        

epoch 120, max_score = 0.2109, mean_score = 0.0036,


                                                                        

epoch 121, max_score = 0.1604, mean_score = -0.0045,


                                                                        

epoch 122, max_score = 0.1190, mean_score = -0.0264,


                                                                        

epoch 123, max_score = 0.2113, mean_score = -0.0104,


                                                                        

epoch 124, max_score = 0.2089, mean_score = -0.0061,


                                                                        

epoch 125, max_score = 0.1455, mean_score = -0.0180,


                                                                        

epoch 126, max_score = 0.2194, mean_score = -0.0077,


                                                                        

epoch 127, max_score = 0.0901, mean_score = -0.0349,


                                                                        

epoch 128, max_score = 0.1418, mean_score = -0.0183,


                                                                        

epoch 129, max_score = 0.1873, mean_score = -0.0202,


                                                                        

epoch 130, max_score = 0.2093, mean_score = 0.0104,


                                                                        

epoch 131, max_score = 0.1612, mean_score = -0.0152,


                                                                        

epoch 132, max_score = 0.1160, mean_score = -0.0215,


                                                                        

epoch 133, max_score = 0.1604, mean_score = 0.0268,


                                                                        

epoch 134, max_score = 0.1527, mean_score = -0.0026,


                                                                        

epoch 135, max_score = 0.1802, mean_score = -0.0093,


                                                                        

epoch 136, max_score = 0.1107, mean_score = -0.0135,


                                                                        

epoch 137, max_score = 0.2776, mean_score = -0.0031,


                                                                        

epoch 138, max_score = 0.1285, mean_score = 0.0007,


                                                                        

epoch 139, max_score = 0.2314, mean_score = -0.0056,


                                                                        

epoch 140, max_score = 0.1931, mean_score = -0.0006,


                                                                        

epoch 141, max_score = 0.1224, mean_score = -0.0107,


                                                                        

epoch 142, max_score = 0.1948, mean_score = 0.0084,


                                                                        

epoch 143, max_score = 0.1204, mean_score = -0.0065,


                                                                        

epoch 144, max_score = 0.1266, mean_score = -0.0089,


                                                                        

epoch 145, max_score = 0.1455, mean_score = -0.0062,


                                                                        

epoch 146, max_score = 0.1239, mean_score = -0.0243,


                                                                        

epoch 147, max_score = 0.2200, mean_score = 0.0037,


                                                                        

epoch 148, max_score = 0.1935, mean_score = 0.0039,


                                                                        

epoch 149, max_score = 0.1139, mean_score = -0.0152,


                                                                        

epoch 150, max_score = 0.1289, mean_score = -0.0131,


                                                                        

epoch 151, max_score = 0.1487, mean_score = -0.0198,


                                                                        

epoch 152, max_score = 0.1903, mean_score = 0.0055,


                                                                        

epoch 153, max_score = 0.2311, mean_score = -0.0190,


                                                                        

epoch 154, max_score = 0.1523, mean_score = -0.0210,


                                                                        

epoch 155, max_score = 0.1390, mean_score = -0.0237,


                                                                        

epoch 156, max_score = 0.2376, mean_score = -0.0204,


                                                                        

epoch 157, max_score = 0.1057, mean_score = -0.0168,


                                                                        

epoch 158, max_score = 0.2432, mean_score = -0.0060,


                                                                        

epoch 159, max_score = 0.1198, mean_score = -0.0149,


                                                                        

epoch 160, max_score = 0.1406, mean_score = -0.0183,


                                                                        

epoch 161, max_score = 0.1974, mean_score = -0.0047,


                                                                        

epoch 162, max_score = 0.1479, mean_score = -0.0195,


                                                                        

epoch 163, max_score = 0.1806, mean_score = -0.0054,


                                                                        

epoch 164, max_score = 0.1691, mean_score = 0.0036,


                                                                        

epoch 165, max_score = 0.1612, mean_score = 0.0007,


                                                                        

epoch 166, max_score = 0.1950, mean_score = -0.0241,


                                                                        

epoch 167, max_score = 0.2500, mean_score = -0.0267,


                                                                        

epoch 168, max_score = 0.1697, mean_score = -0.0101,


                                                                        

epoch 169, max_score = 0.1123, mean_score = -0.0119,


                                                                        

epoch 170, max_score = 0.1929, mean_score = -0.0291,


                                                                        

epoch 171, max_score = 0.2261, mean_score = -0.0136,


                                                                        

epoch 172, max_score = 0.1560, mean_score = -0.0070,


                                                                        

epoch 173, max_score = 0.2032, mean_score = -0.0077,


                                                                        

epoch 174, max_score = 0.1931, mean_score = -0.0167,


                                                                        

epoch 175, max_score = 0.1619, mean_score = -0.0244,


                                                                        

epoch 176, max_score = 0.1382, mean_score = -0.0274,


                                                                        

epoch 177, max_score = 0.1867, mean_score = -0.0257,


                                                                        

epoch 178, max_score = 0.1398, mean_score = -0.0151,


                                                                        

epoch 179, max_score = 0.2630, mean_score = -0.0071,


                                                                        

epoch 180, max_score = 0.1325, mean_score = -0.0060,


                                                                        

epoch 181, max_score = 0.1160, mean_score = -0.0209,


                                                                        

epoch 182, max_score = 0.1639, mean_score = -0.0078,


                                                                        

epoch 183, max_score = 0.1736, mean_score = -0.0119,


                                                                        

epoch 184, max_score = 0.1547, mean_score = -0.0078,


                                                                        

epoch 185, max_score = 0.1507, mean_score = -0.0068,


                                                                        

epoch 186, max_score = 0.1588, mean_score = -0.0157,


                                                                        

epoch 187, max_score = 0.2040, mean_score = -0.0057,


                                                                        

epoch 188, max_score = 0.1192, mean_score = -0.0262,


                                                                        

epoch 189, max_score = 0.1806, mean_score = -0.0149,


                                                                        

epoch 190, max_score = 0.2127, mean_score = 0.0134,


                                                                        

epoch 191, max_score = 0.1681, mean_score = 0.0111,


                                                                        

epoch 192, max_score = 0.2065, mean_score = -0.0163,


                                                                        

epoch 193, max_score = 0.1281, mean_score = -0.0147,


                                                                        

epoch 194, max_score = 0.1895, mean_score = -0.0195,


                                                                        

epoch 195, max_score = 0.1560, mean_score = 0.0046,


                                                                        

epoch 196, max_score = 0.1952, mean_score = -0.0251,


                                                                        

epoch 197, max_score = 0.2095, mean_score = -0.0067,


                                                                        

epoch 198, max_score = 0.2315, mean_score = -0.0288,


                                                                        

epoch 199, max_score = 0.2137, mean_score = -0.0462,


                                                                        

epoch 200, max_score = 0.1451, mean_score = -0.0154,


                                                                        

epoch 201, max_score = 0.2396, mean_score = 0.0071,


                                                                        

epoch 202, max_score = 0.2485, mean_score = -0.0264,


                                                                        

epoch 203, max_score = 0.1236, mean_score = -0.0254,


                                                                        

epoch 204, max_score = 0.1677, mean_score = -0.0111,


                                                                        

epoch 205, max_score = 0.1867, mean_score = -0.0153,


                                                                        

epoch 206, max_score = 0.1248, mean_score = -0.0122,


                                                                        

epoch 207, max_score = 0.2261, mean_score = -0.0253,


                                                                        

epoch 208, max_score = 0.1253, mean_score = -0.0029,


                                                                        

epoch 209, max_score = 0.1449, mean_score = -0.0300,


                                                                        

epoch 210, max_score = 0.1220, mean_score = 0.0034,


                                                                        

epoch 211, max_score = 0.1697, mean_score = -0.0031,


                                                                        

epoch 212, max_score = 0.1087, mean_score = -0.0054,


                                                                        

epoch 213, max_score = 0.1568, mean_score = 0.0133,


                                                                        

epoch 214, max_score = 0.1822, mean_score = -0.0309,


                                                                        

epoch 215, max_score = 0.2077, mean_score = -0.0011,


                                                                        

epoch 216, max_score = 0.1200, mean_score = -0.0190,


                                                                        

epoch 217, max_score = 0.1491, mean_score = -0.0085,


                                                                        

epoch 218, max_score = 0.1640, mean_score = 0.0037,


                                                                        

epoch 219, max_score = 0.1430, mean_score = -0.0142,


                                                                        

epoch 220, max_score = 0.1628, mean_score = -0.0059,


                                                                        

epoch 221, max_score = 0.0954, mean_score = -0.0065,


                                                                        

epoch 222, max_score = 0.1034, mean_score = -0.0246,


                                                                        

epoch 223, max_score = 0.2101, mean_score = 0.0162,


                                                                        

epoch 224, max_score = 0.1087, mean_score = -0.0094,


                                                                        

epoch 225, max_score = 0.1434, mean_score = -0.0160,


                                                                        

epoch 226, max_score = 0.1519, mean_score = -0.0029,


                                                                        

epoch 227, max_score = 0.1103, mean_score = -0.0177,


                                                                        

epoch 228, max_score = 0.1984, mean_score = -0.0047,


                                                                        

epoch 229, max_score = 0.1721, mean_score = 0.0127,


                                                                        

epoch 230, max_score = 0.1859, mean_score = -0.0013,


                                                                        

epoch 231, max_score = 0.0962, mean_score = -0.0106,


                                                                        

epoch 232, max_score = 0.1006, mean_score = -0.0177,


                                                                        

epoch 233, max_score = 0.2073, mean_score = -0.0058,


                                                                        

epoch 234, max_score = 0.1655, mean_score = -0.0116,


                                                                        

epoch 235, max_score = 0.1430, mean_score = -0.0093,


                                                                        

epoch 236, max_score = 0.2444, mean_score = 0.0237,


                                                                        

epoch 237, max_score = 0.1947, mean_score = 0.0190,


                                                                        

epoch 238, max_score = 0.1370, mean_score = 0.0015,


                                                                        

epoch 239, max_score = 0.1204, mean_score = -0.0172,


                                                                        

epoch 240, max_score = 0.1915, mean_score = -0.0054,


                                                                        

epoch 241, max_score = 0.1562, mean_score = 0.0084,


                                                                        

epoch 242, max_score = 0.2376, mean_score = 0.0289,


                                                                        

epoch 243, max_score = 0.2176, mean_score = 0.0420,
Best score updated: 0.0420


                                                                        

epoch 244, max_score = 0.1469, mean_score = -0.0159,


                                                                        

epoch 245, max_score = 0.2140, mean_score = 0.0213,


                                                                        

epoch 246, max_score = 0.2271, mean_score = 0.0255,


                                                                        

epoch 247, max_score = 0.2059, mean_score = 0.0319,


                                                                        

epoch 248, max_score = 0.1784, mean_score = 0.0143,


                                                                        

epoch 249, max_score = 0.1665, mean_score = -0.0346,


                                                                        

epoch 250, max_score = 0.0808, mean_score = -0.0295,


                                                                        

epoch 251, max_score = 0.1749, mean_score = -0.0361,


                                                                        

epoch 252, max_score = 0.1507, mean_score = 0.0014,


                                                                        

epoch 253, max_score = 0.2378, mean_score = -0.0072,


                                                                        

epoch 254, max_score = 0.1152, mean_score = -0.0226,


                                                                        

epoch 255, max_score = 0.2048, mean_score = -0.0202,


                                                                        

epoch 256, max_score = 0.1661, mean_score = -0.0093,


                                                                        

epoch 257, max_score = 0.1277, mean_score = -0.0295,


                                                                        

epoch 258, max_score = 0.0976, mean_score = -0.0265,


                                                                        

epoch 259, max_score = 0.2149, mean_score = -0.0038,


                                                                        

epoch 260, max_score = 0.1644, mean_score = -0.0169,


                                                                        

epoch 261, max_score = 0.1592, mean_score = -0.0339,


                                                                        

epoch 262, max_score = 0.1772, mean_score = -0.0230,


                                                                        

epoch 263, max_score = 0.1497, mean_score = -0.0051,


                                                                        

epoch 264, max_score = 0.1366, mean_score = -0.0063,


                                                                        

epoch 265, max_score = 0.1123, mean_score = -0.0301,


                                                                        

epoch 266, max_score = 0.1533, mean_score = -0.0121,


                                                                        

epoch 267, max_score = 0.1800, mean_score = -0.0204,


                                                                        

epoch 268, max_score = 0.0812, mean_score = -0.0160,


                                                                        

epoch 269, max_score = 0.2218, mean_score = 0.0074,


                                                                        

epoch 270, max_score = 0.1240, mean_score = 0.0052,


                                                                        

epoch 271, max_score = 0.1337, mean_score = -0.0080,


                                                                        

epoch 272, max_score = 0.1604, mean_score = -0.0222,


                                                                          

epoch 273, max_score = 0.1826, mean_score = -0.0021,


                                                                        

epoch 274, max_score = 0.1580, mean_score = -0.0223,


                                                                        

epoch 275, max_score = 0.1208, mean_score = -0.0136,


                                                                        

epoch 276, max_score = 0.1248, mean_score = 0.0226,


                                                                        

epoch 277, max_score = 0.1980, mean_score = 0.0034,


                                                                        

epoch 278, max_score = 0.1446, mean_score = -0.0112,


                                                                        

epoch 279, max_score = 0.1386, mean_score = 0.0012,


                                                                        

epoch 280, max_score = 0.1325, mean_score = -0.0387,


                                                                        

epoch 281, max_score = 0.1051, mean_score = -0.0257,


                                                                        

epoch 282, max_score = 0.1269, mean_score = -0.0099,


                                                                        

epoch 283, max_score = 0.1358, mean_score = -0.0384,


                                                                        

epoch 284, max_score = 0.2149, mean_score = 0.0076,


                                                                        

epoch 285, max_score = 0.1160, mean_score = -0.0264,


                                                                        

epoch 286, max_score = 0.1354, mean_score = -0.0171,


                                                                        

epoch 287, max_score = 0.1228, mean_score = -0.0082,


                                                                        

epoch 288, max_score = 0.1434, mean_score = 0.0028,


                                                                        

epoch 289, max_score = 0.1604, mean_score = -0.0104,


                                                                        

epoch 290, max_score = 0.1313, mean_score = -0.0127,


                                                                        

epoch 291, max_score = 0.1325, mean_score = 0.0033,


                                                                        

epoch 292, max_score = 0.1487, mean_score = -0.0072,


                                                                        

epoch 293, max_score = 0.1531, mean_score = -0.0034,


                                                                        

epoch 294, max_score = 0.1792, mean_score = -0.0115,


                                                                        

epoch 295, max_score = 0.1147, mean_score = -0.0243,


                                                                        

epoch 296, max_score = 0.1277, mean_score = 0.0090,


                                                                        

epoch 297, max_score = 0.1816, mean_score = -0.0024,


                                                                        

epoch 298, max_score = 0.1362, mean_score = -0.0099,


                                                                        

epoch 299, max_score = 0.1632, mean_score = 0.0094,


                                                                        

epoch 300, max_score = 0.1079, mean_score = -0.0264,


                                                                        

epoch 301, max_score = 0.1382, mean_score = -0.0113,


                                                                        

epoch 302, max_score = 0.1992, mean_score = -0.0051,


                                                                        

epoch 303, max_score = 0.1745, mean_score = -0.0117,


                                                                        

epoch 304, max_score = 0.1758, mean_score = -0.0125,


                                                                        

epoch 305, max_score = 0.1382, mean_score = -0.0108,


                                                                        

epoch 306, max_score = 0.1907, mean_score = -0.0071,


                                                                        

epoch 307, max_score = 0.1022, mean_score = -0.0196,


                                                                        

epoch 308, max_score = 0.1988, mean_score = -0.0258,


                                                                        

epoch 309, max_score = 0.1883, mean_score = -0.0121,


                                                                        

epoch 310, max_score = 0.1806, mean_score = -0.0154,


                                                                        

epoch 311, max_score = 0.1386, mean_score = -0.0017,


                                                                        

epoch 312, max_score = 0.1523, mean_score = -0.0190,


                                                                        

epoch 313, max_score = 0.1285, mean_score = -0.0197,


                                                                        

epoch 314, max_score = 0.0998, mean_score = -0.0182,


                                                                        

epoch 315, max_score = 0.2137, mean_score = -0.0118,


                                                                        

epoch 316, max_score = 0.1644, mean_score = 0.0024,


                                                                        

epoch 317, max_score = 0.1939, mean_score = -0.0040,


                                                                        

epoch 318, max_score = 0.1653, mean_score = -0.0171,


                                                                        

epoch 319, max_score = 0.1180, mean_score = -0.0055,


                                                                        

epoch 320, max_score = 0.1600, mean_score = -0.0237,


                                                                        

epoch 321, max_score = 0.1915, mean_score = 0.0143,


                                                                        

epoch 322, max_score = 0.1164, mean_score = -0.0154,


                                                                        

epoch 323, max_score = 0.1697, mean_score = 0.0018,


                                                                        

epoch 324, max_score = 0.1588, mean_score = 0.0038,


                                                                        

epoch 325, max_score = 0.2020, mean_score = -0.0280,


                                                                        

epoch 326, max_score = 0.1640, mean_score = 0.0077,


epoch: 327 loss: 0.10:  57%|█████▋    | 158/276 [02:13<01:47,  1.09it/s]