In [1]:
from pathlib import Path

from src.metric import cls_eval_funs, reg_eval_funs
from src.trainer import Trainer
from src.utils import load_pkl_data, set_random_seed

In [13]:
class Config:
    data_dir = '/home/jiawei/Desktop/github/DOFEN/tabular-benchmark/tabular_benchmark_data'
    data_id = '361060'
    n_epoch = 250
    batch_size = 256
    target_transform = False

In [73]:
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from src.models.base import BaseClassifier, BaseRegressor
from src.models.layers import ResidualLayer, ReGLU


class Reshape(nn.Module):
    def __init__(self, *args: int) -> None:
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(self.shape)


class FastGroupConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        self.fast_mode = kwargs.pop('fast_mode')
        nn.Conv1d.__init__(self, *args, **kwargs)
        if self.groups > self.fast_mode:
            self.weight = nn.Parameter(
                self.weight.reshape(
                    self.groups, self.out_channels // self.groups, self.in_channels // self.groups, 1
                ).permute(3, 0, 2, 1)
            )
            self.bias = nn.Parameter(
                self.bias.unsqueeze(0).unsqueeze(-1)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.groups > self.fast_mode:
            x = x.reshape(-1, self.groups, self.in_channels // self.groups, 1)
            return (x * self.weight).sum(2, keepdims=True).permute(0, 1, 3, 2).reshape(-1, self.out_channels, 1) + self.bias
        else:
            return self._conv_forward(x, self.weight, self.bias)


class ConditionGeneration(nn.Module):
    def __init__(
        self,
        category_column_count: List[int],
        n_cond: int = 128,
        categorical_optimized: bool = False,
        fast_mode: int = 64,
    ):
        super(ConditionGeneration, self).__init__()
        self.fast_mode = fast_mode
        self.categorical_optimized = categorical_optimized

        index_info = self.extract_feature_metadata(category_column_count)
        self.numerical_index = index_info['numerical_index']
        self.categorical_index = index_info['categorical_index']
        self.categorical_count = index_info['categorical_count']

        categorical_offset = torch.tensor([0] + np.cumsum(self.categorical_count).tolist()[:-1]).long()
        self.register_buffer('categorical_offset', categorical_offset)

        self.n_cond = n_cond
        self.phi_1 = self.get_phi_1()

    def extract_feature_metadata(self, category_column_count: List[int]) -> Dict[str, List[int]]:
        numerical_index = [i for i, count in enumerate(category_column_count) if count == -1]
        categorical_index = [i for i, count in enumerate(category_column_count) if count != -1]
        categorical_count = [count for count in category_column_count if count != -1]
        return {
            'numerical_index': numerical_index,
            'categorical_index': categorical_index,
            'categorical_count': categorical_count,
        }

    def get_phi_1(self) -> nn.ModuleDict:
        phi_1 = nn.ModuleDict()
        if len(self.numerical_index):
            phi_1['num'] = nn.Sequential(
                # input = (b, n_num_col)
                # output = (b, n_num_col, n_cond)
                Reshape(-1, len(self.numerical_index), 1),
                FastGroupConv1d(
                    len(self.numerical_index),
                    len(self.numerical_index) * self.n_cond,
                    kernel_size=1,
                    groups=len(self.numerical_index),
                    fast_mode=self.fast_mode
                ), # (b, n_num_col, 1) -> (b, n_num_col * n_cond, 1)
                nn.Sigmoid(),
                Reshape(-1, len(self.numerical_index), self.n_cond)
            )
        if len(self.categorical_index):
            phi_1['cat'] = nn.ModuleDict()
            phi_1['cat']['embedder'] = nn.Embedding(sum(self.categorical_count), self.n_cond)            
            phi_1['cat']['mapper'] = nn.Sequential(
                # input = (b, n_cat_col, n_cond)
                # output = (b, n_cat_col, n_cond)
                Reshape(-1, len(self.categorical_index) * self.n_cond, 1),
                nn.GroupNorm(len(self.categorical_index), len(self.categorical_index) * self.n_cond),
                FastGroupConv1d(len(self.categorical_index) * self.n_cond, len(self.categorical_index)*self.n_cond, kernel_size=1, groups=len(self.categorical_index)*self.n_cond if self.categorical_optimized else len(self.categorical_index), fast_mode=self.fast_mode),                
                nn.Sigmoid(),
                Reshape(-1, len(self.categorical_index), self.n_cond)
            )
        return phi_1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        M = []

        if len(self.numerical_index):
            num_x = x[:, self.numerical_index].float()
            num_sample_emb = self.phi_1['num'](num_x)
            M.append(num_sample_emb)

        if len(self.categorical_index):
            cat_x = x[:, self.categorical_index].long() + self.cat_offset
            cat_sample_emb = self.phi_1['cat']['mapper'](self.phi_1['cat']['embedder'](cat_x))
            M.append(cat_sample_emb)

        M = torch.cat(M, dim=1) # (b, n_col, n_cond)
        M = M.permute(0, 2, 1) # (b, n_cond, n_col)
        return M


class rODTConstruction(nn.Module):
    def __init__(self, n_cond: int, n_col: int, d: int) -> None:
        super().__init__()
        self.permutator = torch.rand(n_cond * n_col).argsort(-1)
        self.d = d

    def forward(self, M: torch.Tensor) -> torch.Tensor:
        b, _, _, embed_dim = M.shape
        return M.reshape(b, -1, embed_dim)[:, self.permutator, :].reshape(b, -1, self.d, embed_dim)


class rODTForestConstruction(nn.Module):
    def __init__(
        self,
        n_col: int,
        n_rodt: int,
        n_cond: int,
        n_estimator: int,
        n_head: int = 1,
        n_hidden: int = 128,
        n_forest: int = 100,
        dropout: float = 0.0,
        fast_mode: int = 64,
        device = torch.device('cuda')
    ) -> None:

        super().__init__()
        self.device = device
        self.n_estimator = n_estimator
        self.n_forest = n_forest
        self.n_rodt = n_rodt
        self.n_head = n_head
        self.n_hidden = n_hidden

        self.sample_without_replacement_eval = self.get_sample_without_replacement()

    def get_sample_without_replacement(self) -> torch.Tensor:
        return torch.rand(self.n_forest, self.n_rodt, device=self.device).argsort(-1)[:, :self.n_estimator]

    def forward(self, w, E) -> torch.Tensor:
        # w: (b, n_rodt, 1)
        # E: (b, n_rodt, n_hidden)
    
        sample_without_replacement = self.get_sample_without_replacement() if self.training else self.sample_without_replacement_eval

        w_prime = w[:, sample_without_replacement].softmax(-2) # (b, n_forest, n_rodt, 1)
        E_prime = E[:, sample_without_replacement].reshape(
            E.shape[0], self.n_forest, self.n_estimator, self.n_hidden
        ) # (b, n_forest, n_rodt, n_hidden)

        F = (w_prime * E_prime).sum(-2).reshape(
            E.shape[0], self.n_forest, self.n_hidden
        ) # (b, n_forest, n_hidden)
        return F


class rODTForestBagging(nn.Module):
    def __init__(self, n_hidden: int, dropout: float, n_class: int) -> None:
        super().__init__()
        self.phi_3 = nn.Sequential(
            nn.LayerNorm(n_hidden),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.LayerNorm(n_hidden),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_class)
        )

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        return self.phi_3(F) # (b, n_forest, n_class)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.qw_proj = self.create_linear(embed_dim, embed_dim)
        self.kw_proj = self.create_linear(embed_dim, embed_dim)
        self.vw_proj = self.create_linear(embed_dim, embed_dim)
        self.ow_proj = self.create_linear(embed_dim, 1)

        self.qE_proj = self.create_linear(embed_dim, embed_dim)
        self.kE_proj = self.create_linear(embed_dim, embed_dim)
        self.vE_proj = self.create_linear(embed_dim, embed_dim)
        self.oE_proj = self.create_linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def create_linear(self, in_features: int, out_features: int):
        linear = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(linear.weight, gain=1 / 2 ** 0.5)
        return linear

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        query: Tensor of shape (batch_size, tgt_len, embed_dim)
        key: Tensor of shape (batch_size, src_len, embed_dim)
        value: Tensor of shape (batch_size, src_len, embed_dim)
        attn_mask: Optional[Tensor] of shape (tgt_len, src_len) or (batch_size, tgt_len, src_len)
        """
        batch_size, tgt_len, _ = query.size()
        src_len = key.size(1)

        # Compute w
        qw = self.qw_proj(query)  # shape: (batch_size, tgt_len, embed_dim)
        kw = self.kw_proj(key)    # shape: (batch_size, src_len, embed_dim)
        vw = self.vw_proj(value)  # shape: (batch_size, src_len, 1)

        # Reshape into multihead format
        qw = qw.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, tgt_len, head_dim)
        kw = kw.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, head_dim)
        vw = vw.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, 1)

        attn_w = F.softmax(torch.matmul(qw, kw.transpose(-2, -1)) / self.scaling, dim=-1) # shape: (batch_size, num_heads, tgt_len, src_len)
        attn_w = self.dropout(attn_w) 
        attn_w_output = torch.matmul(attn_w, vw) # shape: (batch_size, num_heads, tgt_len, head_dim)

        # Reshape back to original dimensions
        attn_w_output = attn_w_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) # shape: (batch_size, tgt_len, embed_dim)
        w_output = self.ow_proj(attn_w_output) # shape: (batch_size, tgt_len, 1)
        w_output = w_output.mean(1)

        # Compute w
        qE = self.qE_proj(query)  # shape: (batch_size, tgt_len, embed_dim)
        kE = self.kE_proj(key)    # shape: (batch_size, src_len, embed_dim)
        vE = self.vE_proj(value)  # shape: (batch_size, src_len, embed_dim)

        # Reshape into multihead format
        qE = qE.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, tgt_len, head_dim)
        kE = kE.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, head_dim)
        vE = vE.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, 1)

        attn_E = F.softmax(torch.matmul(qE, kE.transpose(-2, -1)) / self.scaling, dim=-1) # shape: (batch_size, num_heads, tgt_len, src_len)
        attn_E = self.dropout(attn_E) 
        attn_E_output = torch.matmul(attn_E, vE) # shape: (batch_size, num_heads, tgt_len, head_dim)

        # Reshape back to original dimensions
        attn_E_output = attn_E_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) # shape: (batch_size, tgt_len, embed_dim)
        E_output = self.oE_proj(attn_E_output) # shape: (batch_size, tgt_len, embed_dim)
        E_output = E_output.mean(1)

        return w_output, E_output


class DOFEN(nn.Module):
    def __init__(
        self,
        category_column_count: List[int],
        n_class: int, 
        m: int = 16, 
        d: int = 4, 
        n_head: int = 1,
        n_forest: int = 100,
        n_hidden: int = 128,
        dropout: float = 0.0, 
        categorical_optimized: bool = False,
        fast_mode: int = 2048,
        use_bagging_loss: bool = False,
        device=torch.device('cuda'),
    ):
        super().__init__()

        self.device = device
        self.n_class = 1 if n_class == -1 else n_class
        self.is_rgr = True if n_class == -1 else False

        self.m = m
        self.d = d
        self.n_head = n_head
        self.n_forest = n_forest
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.use_bagging_loss = use_bagging_loss

        self.n_cond = self.d * self.m
        self.n_col = len(category_column_count)
        self.n_rodt = self.n_cond * self.n_col // self.d
        self.n_estimator = max(2, int(self.n_col ** 0.5)) * self.n_cond // self.d

        self.condition_generation = ConditionGeneration(            
            category_column_count, 
            n_cond=self.n_cond, 
            categorical_optimized=categorical_optimized, 
            fast_mode=fast_mode,
        )

        self.proj = nn.Linear(1, self.n_hidden)
        self.rodt_construction = rODTConstruction(
            self.n_cond,
            self.n_col,
            self.d,
        )

        self.norm = nn.LayerNorm(n_hidden)
        self.attn = MultiheadAttention(
            embed_dim=self.n_hidden,
            num_heads=n_head,
            dropout=0.2,
        )

        self.rodt_forest_construction = rODTForestConstruction(
            self.n_col, 
            self.n_rodt, 
            self.n_cond,
            self.n_estimator,
            n_head=self.n_head, 
            n_hidden=self.n_hidden,
            n_forest=self.n_forest,
            dropout=self.dropout,
            fast_mode=fast_mode,
            device=self.device
        )
        self.rodt_forest_bagging = rODTForestBagging(
            self.n_hidden,
            self.dropout,
            self.n_class
        )

        self.ffn_dropout_rate = 0.1
        self.act = ReGLU()

    def compute_loss(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def evaluate(self, X: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        M = self.condition_generation(x)           # (b, n_cond, n_col)
        M = M.unsqueeze(-1)                        # (b, n_cond, n_col, 1)
        M = self.proj(M)                           # (b, n_cond, n_col, n_hidden)
        O = self.rodt_construction(M)              # (b, n_rodt, d, n_hidden)
        O = O.reshape(-1, self.d, self.n_hidden)   # (b * n_rodt, d, n_hidden)
        O = self.norm(O)
        w, E = self.attn(O, O, O)

        w = w.reshape(-1, self.n_rodt, 1)
        E = E.reshape(-1, self.n_rodt, self.n_hidden)

        E_norm = E / E.norm(dim=-1, keepdim=True)
        sim_loss = torch.einsum('bij,bjk->bik', E_norm, E_norm.permute(0, 2, 1))

        F = self.rodt_forest_construction(w, E) # (b, n_forest, n_hidden)
        y_hats = self.rodt_forest_bagging(F)    # (b, n_rodt, n_class)
        y_hat = y_hats.mean(1)                  # (b, n_class)

        if y is not None:
            loss = self.compute_loss(
                y_hats.permute(0, 2, 1) if not self.is_rgr else y_hats, 
                y.unsqueeze(-1).expand(-1, self.n_forest)
            ) +  sim_loss.abs().mean()
            if self.n_forest > 1 and self.training and self.use_bagging_loss:
                loss += self.compute_loss(y_hat, y)
            return {'pred': y_hat, 'loss': loss}
        return {'pred': y_hat}

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)['pred']


class DOFENClassifier(BaseClassifier, DOFEN):
    def __init__(
        self,
        category_column_count: List[int],
        n_class: int, 
        m: int = 16,
        d: int = 4,
        n_head: int = 4,
        n_forest: int = 100,
        n_hidden: int = 16,
        dropout: float = 0.0, 
        categorical_optimized: bool = False,
        fast_mode: int = 2048,
        use_bagging_loss: bool = False,
        device=torch.device('cuda'),
    ) -> None:

        super().__init__(
            category_column_count=category_column_count, 
            n_class=n_class, 
            m=m, 
            d=d, 
            n_head=n_head,
            n_forest=n_forest,
            n_hidden=n_hidden,
            dropout=dropout, 
            categorical_optimized=categorical_optimized,
            fast_mode=fast_mode,
            use_bagging_loss=use_bagging_loss,
            device=device,
        )

In [74]:
args = Config()
set_random_seed()

data_dict = load_pkl_data(Path(args.data_dir, args.data_id, '0.pkl'))
n_class = data_dict['label_cat_count']
task = 'r' if n_class == -1 else 'c'

target_transform = args.target_transform if task == 'r' else False
train_X = data_dict['x_train']
train_y = data_dict['y_train' if not target_transform else 'y_train_transform']
valid_X = data_dict['x_val']
valid_y = data_dict['y_val' if not target_transform else 'y_val_transform']
test_X = data_dict['x_test']
test_y = data_dict['y_test' if not target_transform else 'y_test_transform']

data_args = {
    'n_feature': train_X.shape[1],
    'n_train': train_X.shape[0],
    'n_valid': valid_X.shape[0],
    'n_test': test_X.shape[0],
}

model_params = {
    'category_column_count': data_dict['col_cat_count'],
}
if n_class != -1:
    model_params['n_class'] = n_class


model = DOFENClassifier(**model_params)
eval_funs = cls_eval_funs
metrics = 'accuracy'

wandb = None
trainer = Trainer(
    model,
    batch_size=args.batch_size,
    n_epoch=args.n_epoch,
    eval_funs=eval_funs,
    metric=metrics,
    logger=wandb,
)
trainer.fit(
    train_X=train_X, train_y=train_y,
    valid_X=valid_X, valid_y=valid_y,
    test_X=test_X, test_y=test_y,
)

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 1/250 [00:00<03:03,  1.36it/s]

{'epoch': 1, 'valid_accuracy': 0.501053617419808, 'best_valid_epoch': 1, 'best_valid_score': 0.501053617419808, 'test_accuracy': 0.503712622917921, 'selected_test_score': 0.503712622917921}


  1%|          | 2/250 [00:01<03:01,  1.37it/s]

{'epoch': 2, 'valid_accuracy': 0.49929758838679467, 'best_valid_epoch': 1, 'best_valid_score': 0.501053617419808, 'test_accuracy': 0.5011539233393538, 'selected_test_score': 0.503712622917921}


  1%|          | 3/250 [00:02<03:00,  1.37it/s]

{'epoch': 3, 'valid_accuracy': 0.5204870053851557, 'best_valid_epoch': 3, 'best_valid_score': 0.5204870053851557, 'test_accuracy': 0.5209211318482841, 'selected_test_score': 0.5209211318482841}


  2%|▏         | 4/250 [00:02<02:59,  1.37it/s]

{'epoch': 4, 'valid_accuracy': 0.49906345118239287, 'best_valid_epoch': 3, 'best_valid_score': 0.5204870053851557, 'test_accuracy': 0.5009532410194661, 'selected_test_score': 0.5209211318482841}


  2%|▏         | 5/250 [00:03<02:58,  1.37it/s]

{'epoch': 5, 'valid_accuracy': 0.5539686256146101, 'best_valid_epoch': 5, 'best_valid_score': 0.5539686256146101, 'test_accuracy': 0.5542343969496287, 'selected_test_score': 0.5542343969496287}


  2%|▏         | 6/250 [00:04<02:58,  1.37it/s]

{'epoch': 6, 'valid_accuracy': 0.5029267150550223, 'best_valid_epoch': 5, 'best_valid_score': 0.5539686256146101, 'test_accuracy': 0.5043648404575557, 'selected_test_score': 0.5542343969496287}


  3%|▎         | 7/250 [00:05<03:02,  1.33it/s]

{'epoch': 7, 'valid_accuracy': 0.556192929056427, 'best_valid_epoch': 7, 'best_valid_score': 0.556192929056427, 'test_accuracy': 0.556191049568533, 'selected_test_score': 0.556191049568533}


  3%|▎         | 8/250 [00:05<03:00,  1.34it/s]

{'epoch': 8, 'valid_accuracy': 0.5523296651837977, 'best_valid_epoch': 7, 'best_valid_score': 0.556192929056427, 'test_accuracy': 0.5469596628537026, 'selected_test_score': 0.556191049568533}


  4%|▎         | 9/250 [00:06<02:58,  1.35it/s]

{'epoch': 9, 'valid_accuracy': 0.5687192694919223, 'best_valid_epoch': 9, 'best_valid_score': 0.5687192694919223, 'test_accuracy': 0.562010836845274, 'selected_test_score': 0.562010836845274}


  4%|▍         | 10/250 [00:07<02:56,  1.36it/s]

{'epoch': 10, 'valid_accuracy': 0.5916647155232967, 'best_valid_epoch': 10, 'best_valid_score': 0.5916647155232967, 'test_accuracy': 0.5991872366044552, 'selected_test_score': 0.5991872366044552}


  4%|▍         | 11/250 [00:08<02:55,  1.36it/s]

{'epoch': 11, 'valid_accuracy': 0.6123858581128542, 'best_valid_epoch': 11, 'best_valid_score': 0.6123858581128542, 'test_accuracy': 0.6088199879590608, 'selected_test_score': 0.6088199879590608}


  5%|▍         | 12/250 [00:08<02:54,  1.36it/s]

{'epoch': 12, 'valid_accuracy': 0.5887380004682744, 'best_valid_epoch': 11, 'best_valid_score': 0.6123858581128542, 'test_accuracy': 0.5932671081677704, 'selected_test_score': 0.6088199879590608}


  5%|▌         | 13/250 [00:09<02:53,  1.37it/s]

{'epoch': 13, 'valid_accuracy': 0.6753687660969327, 'best_valid_epoch': 13, 'best_valid_score': 0.6753687660969327, 'test_accuracy': 0.6751956652618905, 'selected_test_score': 0.6751956652618905}


  6%|▌         | 14/250 [00:10<02:52,  1.37it/s]

{'epoch': 14, 'valid_accuracy': 0.6966752516974948, 'best_valid_epoch': 14, 'best_valid_score': 0.6966752516974948, 'test_accuracy': 0.696969696969697, 'selected_test_score': 0.696969696969697}


  6%|▌         | 15/250 [00:11<02:51,  1.37it/s]

{'epoch': 15, 'valid_accuracy': 0.6499648794193398, 'best_valid_epoch': 14, 'best_valid_score': 0.6966752516974948, 'test_accuracy': 0.6485049167168373, 'selected_test_score': 0.696969696969697}


  6%|▋         | 16/250 [00:11<02:51,  1.37it/s]

{'epoch': 16, 'valid_accuracy': 0.7012409271833294, 'best_valid_epoch': 16, 'best_valid_score': 0.7012409271833294, 'test_accuracy': 0.699076861328517, 'selected_test_score': 0.699076861328517}


  7%|▋         | 17/250 [00:12<02:50,  1.37it/s]

{'epoch': 17, 'valid_accuracy': 0.7019433387965348, 'best_valid_epoch': 17, 'best_valid_score': 0.7019433387965348, 'test_accuracy': 0.7018362432269717, 'selected_test_score': 0.7018362432269717}


  7%|▋         | 18/250 [00:13<02:49,  1.37it/s]

{'epoch': 18, 'valid_accuracy': 0.7163427768672442, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7177904876580373, 'selected_test_score': 0.7177904876580373}


  8%|▊         | 19/250 [00:14<02:53,  1.33it/s]

{'epoch': 19, 'valid_accuracy': 0.6676422383516741, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.6682721252257676, 'selected_test_score': 0.7177904876580373}


  8%|▊         | 20/250 [00:14<02:51,  1.34it/s]

{'epoch': 20, 'valid_accuracy': 0.7132989932100211, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7128236002408188, 'selected_test_score': 0.7177904876580373}


  8%|▊         | 21/250 [00:15<02:49,  1.35it/s]

{'epoch': 21, 'valid_accuracy': 0.6520721142589557, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.6505619104956853, 'selected_test_score': 0.7177904876580373}


  9%|▉         | 22/250 [00:16<02:47,  1.36it/s]

{'epoch': 22, 'valid_accuracy': 0.7128307188012175, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7140276941601444, 'selected_test_score': 0.7177904876580373}


  9%|▉         | 23/250 [00:16<02:46,  1.36it/s]

{'epoch': 23, 'valid_accuracy': 0.6985483493327089, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.6958659442103151, 'selected_test_score': 0.7177904876580373}


 10%|▉         | 24/250 [00:17<02:45,  1.36it/s]

{'epoch': 24, 'valid_accuracy': 0.6799344415827675, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.68176801123821, 'selected_test_score': 0.7177904876580373}


 10%|█         | 25/250 [00:18<02:44,  1.37it/s]

{'epoch': 25, 'valid_accuracy': 0.7109576211660033, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.713726670680313, 'selected_test_score': 0.7177904876580373}


 10%|█         | 26/250 [00:19<02:43,  1.37it/s]

{'epoch': 26, 'valid_accuracy': 0.6271365019901662, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.6245233794902669, 'selected_test_score': 0.7177904876580373}


 11%|█         | 27/250 [00:19<02:42,  1.37it/s]

{'epoch': 27, 'valid_accuracy': 0.7088503863263873, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7075055187637969, 'selected_test_score': 0.7177904876580373}


 11%|█         | 28/250 [00:20<02:42,  1.37it/s]

{'epoch': 28, 'valid_accuracy': 0.7024116132053383, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.6990266907485451, 'selected_test_score': 0.7177904876580373}


 12%|█▏        | 29/250 [00:21<02:41,  1.37it/s]

{'epoch': 29, 'valid_accuracy': 0.7067431514867712, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7064017660044151, 'selected_test_score': 0.7177904876580373}


 12%|█▏        | 30/250 [00:22<02:40,  1.37it/s]

{'epoch': 30, 'valid_accuracy': 0.6912900959962538, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.688591210114389, 'selected_test_score': 0.7177904876580373}


 12%|█▏        | 31/250 [00:22<02:44,  1.33it/s]

{'epoch': 31, 'valid_accuracy': 0.7129477874034184, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7157334938791893, 'selected_test_score': 0.7177904876580373}


 13%|█▎        | 32/250 [00:23<02:42,  1.34it/s]

{'epoch': 32, 'valid_accuracy': 0.6994848981503161, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.6973710616094722, 'selected_test_score': 0.7177904876580373}


 13%|█▎        | 33/250 [00:24<02:40,  1.35it/s]

{'epoch': 33, 'valid_accuracy': 0.7127136501990167, 'best_valid_epoch': 18, 'best_valid_score': 0.7163427768672442, 'test_accuracy': 0.7144792293798916, 'selected_test_score': 0.7177904876580373}


 14%|█▎        | 34/250 [00:25<02:39,  1.36it/s]

{'epoch': 34, 'valid_accuracy': 0.7257082650433154, 'best_valid_epoch': 34, 'best_valid_score': 0.7257082650433154, 'test_accuracy': 0.7233092514549468, 'selected_test_score': 0.7233092514549468}


 14%|█▍        | 35/250 [00:25<02:38,  1.36it/s]

{'epoch': 35, 'valid_accuracy': 0.7066260828845704, 'best_valid_epoch': 34, 'best_valid_score': 0.7257082650433154, 'test_accuracy': 0.7056492073048365, 'selected_test_score': 0.7233092514549468}


 14%|█▍        | 36/250 [00:26<02:36,  1.36it/s]

{'epoch': 36, 'valid_accuracy': 0.7116600327792086, 'best_valid_epoch': 34, 'best_valid_score': 0.7257082650433154, 'test_accuracy': 0.7103150712422236, 'selected_test_score': 0.7233092514549468}


 15%|█▍        | 37/250 [00:27<02:35,  1.37it/s]

{'epoch': 37, 'valid_accuracy': 0.6976118005151019, 'best_valid_epoch': 34, 'best_valid_score': 0.7257082650433154, 'test_accuracy': 0.6962171382701184, 'selected_test_score': 0.7233092514549468}


 15%|█▌        | 38/250 [00:27<02:35,  1.37it/s]

{'epoch': 38, 'valid_accuracy': 0.6985483493327089, 'best_valid_epoch': 34, 'best_valid_score': 0.7257082650433154, 'test_accuracy': 0.6956150913104555, 'selected_test_score': 0.7233092514549468}


 16%|█▌        | 39/250 [00:28<02:34,  1.37it/s]

{'epoch': 39, 'valid_accuracy': 0.7268789510653243, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.725316074653823, 'selected_test_score': 0.725316074653823}


 16%|█▌        | 40/250 [00:29<02:33,  1.37it/s]

{'epoch': 40, 'valid_accuracy': 0.7132989932100211, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.715833835039133, 'selected_test_score': 0.725316074653823}


 16%|█▋        | 41/250 [00:30<02:32,  1.37it/s]

{'epoch': 41, 'valid_accuracy': 0.6970264575040974, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.6946116797110174, 'selected_test_score': 0.725316074653823}


 17%|█▋        | 42/250 [00:30<02:31,  1.37it/s]

{'epoch': 42, 'valid_accuracy': 0.7159915710606415, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.7188942404174192, 'selected_test_score': 0.725316074653823}


 17%|█▋        | 43/250 [00:31<02:35,  1.33it/s]

{'epoch': 43, 'valid_accuracy': 0.6848513228752049, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.682771422837648, 'selected_test_score': 0.725316074653823}


 18%|█▊        | 44/250 [00:32<02:33,  1.34it/s]

{'epoch': 44, 'valid_accuracy': 0.7013579957855304, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.6994782259682922, 'selected_test_score': 0.725316074653823}


 18%|█▊        | 45/250 [00:33<02:31,  1.35it/s]

{'epoch': 45, 'valid_accuracy': 0.7224303441816905, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.7246638571141882, 'selected_test_score': 0.725316074653823}


 18%|█▊        | 46/250 [00:33<02:30,  1.36it/s]

{'epoch': 46, 'valid_accuracy': 0.6994848981503161, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.6973710616094722, 'selected_test_score': 0.725316074653823}


 19%|█▉        | 47/250 [00:34<02:29,  1.36it/s]

{'epoch': 47, 'valid_accuracy': 0.7248887848279092, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.726871362632952, 'selected_test_score': 0.725316074653823}


 19%|█▉        | 48/250 [00:35<02:28,  1.36it/s]

{'epoch': 48, 'valid_accuracy': 0.6712713650199017, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.669626730885009, 'selected_test_score': 0.725316074653823}


 20%|█▉        | 49/250 [00:36<02:27,  1.37it/s]

{'epoch': 49, 'valid_accuracy': 0.7137672676188246, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.7189945815773631, 'selected_test_score': 0.725316074653823}


 20%|██        | 50/250 [00:36<02:26,  1.37it/s]

{'epoch': 50, 'valid_accuracy': 0.7121283071880121, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.7149809351796107, 'selected_test_score': 0.725316074653823}


 20%|██        | 51/250 [00:37<02:25,  1.37it/s]

{'epoch': 51, 'valid_accuracy': 0.7140014048232264, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.7176399759181216, 'selected_test_score': 0.725316074653823}


 21%|██        | 52/250 [00:38<02:24,  1.37it/s]

{'epoch': 52, 'valid_accuracy': 0.7123624443924139, 'best_valid_epoch': 39, 'best_valid_score': 0.7268789510653243, 'test_accuracy': 0.7157836644591612, 'selected_test_score': 0.725316074653823}


 21%|██        | 53/250 [00:38<02:23,  1.37it/s]

{'epoch': 53, 'valid_accuracy': 0.7308592835401545, 'best_valid_epoch': 53, 'best_valid_score': 0.7308592835401545, 'test_accuracy': 0.7307344972907887, 'selected_test_score': 0.7307344972907887}


 22%|██▏       | 54/250 [00:39<02:23,  1.37it/s]

{'epoch': 54, 'valid_accuracy': 0.7247717162257082, 'best_valid_epoch': 53, 'best_valid_score': 0.7308592835401545, 'test_accuracy': 0.7273228978526992, 'selected_test_score': 0.7307344972907887}


 22%|██▏       | 55/250 [00:40<02:26,  1.33it/s]

{'epoch': 55, 'valid_accuracy': 0.7026457504097401, 'best_valid_epoch': 53, 'best_valid_score': 0.7308592835401545, 'test_accuracy': 0.7014850491671684, 'selected_test_score': 0.7307344972907887}


 22%|██▏       | 56/250 [00:41<02:24,  1.34it/s]

{'epoch': 56, 'valid_accuracy': 0.7307422149379537, 'best_valid_epoch': 53, 'best_valid_score': 0.7308592835401545, 'test_accuracy': 0.7304334738109572, 'selected_test_score': 0.7307344972907887}


 23%|██▎       | 57/250 [00:41<02:22,  1.35it/s]

{'epoch': 57, 'valid_accuracy': 0.7116600327792086, 'best_valid_epoch': 53, 'best_valid_score': 0.7308592835401545, 'test_accuracy': 0.7139775235801726, 'selected_test_score': 0.7307344972907887}


 23%|██▎       | 58/250 [00:42<02:21,  1.36it/s]

{'epoch': 58, 'valid_accuracy': 0.722898618590494, 'best_valid_epoch': 53, 'best_valid_score': 0.7308592835401545, 'test_accuracy': 0.7251153923339354, 'selected_test_score': 0.7307344972907887}


 24%|██▎       | 59/250 [00:43<02:20,  1.36it/s]

{'epoch': 59, 'valid_accuracy': 0.7310934207445563, 'best_valid_epoch': 59, 'best_valid_score': 0.7310934207445563, 'test_accuracy': 0.7314870559903672, 'selected_test_score': 0.7314870559903672}


 24%|██▍       | 60/250 [00:44<02:19,  1.36it/s]

{'epoch': 60, 'valid_accuracy': 0.7321470381643643, 'best_valid_epoch': 60, 'best_valid_score': 0.7321470381643643, 'test_accuracy': 0.732139273530002, 'selected_test_score': 0.732139273530002}


 24%|██▍       | 61/250 [00:44<02:18,  1.37it/s]

{'epoch': 61, 'valid_accuracy': 0.7122453757902131, 'best_valid_epoch': 60, 'best_valid_score': 0.7321470381643643, 'test_accuracy': 0.7117198474814369, 'selected_test_score': 0.732139273530002}


 25%|██▍       | 62/250 [00:45<02:17,  1.37it/s]

{'epoch': 62, 'valid_accuracy': 0.7261765394521189, 'best_valid_epoch': 60, 'best_valid_score': 0.7321470381643643, 'test_accuracy': 0.7286273329319687, 'selected_test_score': 0.732139273530002}


 25%|██▌       | 63/250 [00:46<02:16,  1.37it/s]

{'epoch': 63, 'valid_accuracy': 0.7326153125731679, 'best_valid_epoch': 63, 'best_valid_score': 0.7326153125731679, 'test_accuracy': 0.7332931968693558, 'selected_test_score': 0.7332931968693558}


 26%|██▌       | 64/250 [00:47<02:15,  1.37it/s]

{'epoch': 64, 'valid_accuracy': 0.7147038164364318, 'best_valid_epoch': 63, 'best_valid_score': 0.7326153125731679, 'test_accuracy': 0.7173891230182621, 'selected_test_score': 0.7332931968693558}


 26%|██▌       | 65/250 [00:47<02:15,  1.37it/s]

{'epoch': 65, 'valid_accuracy': 0.7156403652540388, 'best_valid_epoch': 63, 'best_valid_score': 0.7326153125731679, 'test_accuracy': 0.7192956050571945, 'selected_test_score': 0.7332931968693558}


 26%|██▋       | 66/250 [00:48<02:14,  1.37it/s]

{'epoch': 66, 'valid_accuracy': 0.7189182861156638, 'best_valid_epoch': 63, 'best_valid_score': 0.7326153125731679, 'test_accuracy': 0.7226068633353402, 'selected_test_score': 0.7332931968693558}


 27%|██▋       | 67/250 [00:49<02:17,  1.33it/s]

{'epoch': 67, 'valid_accuracy': 0.7344884102083821, 'best_valid_epoch': 67, 'best_valid_score': 0.7344884102083821, 'test_accuracy': 0.7352498494882601, 'selected_test_score': 0.7352498494882601}


 27%|██▋       | 68/250 [00:50<02:15,  1.34it/s]

{'epoch': 68, 'valid_accuracy': 0.7347225474127839, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.735300020068232, 'selected_test_score': 0.735300020068232}


 28%|██▊       | 69/250 [00:50<02:13,  1.35it/s]

{'epoch': 69, 'valid_accuracy': 0.7253570592367127, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.7285269917720248, 'selected_test_score': 0.735300020068232}


 28%|██▊       | 70/250 [00:51<02:12,  1.36it/s]

{'epoch': 70, 'valid_accuracy': 0.7344884102083821, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.735300020068232, 'selected_test_score': 0.735300020068232}


 28%|██▊       | 71/250 [00:52<02:11,  1.36it/s]

{'epoch': 71, 'valid_accuracy': 0.7344884102083821, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.7354003612281758, 'selected_test_score': 0.735300020068232}


 29%|██▉       | 72/250 [00:52<02:10,  1.36it/s]

{'epoch': 72, 'valid_accuracy': 0.7209084523530789, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.7240116395745535, 'selected_test_score': 0.735300020068232}


 29%|██▉       | 73/250 [00:53<02:09,  1.37it/s]

{'epoch': 73, 'valid_accuracy': 0.7188012175134629, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.7218543046357616, 'selected_test_score': 0.735300020068232}


 30%|██▉       | 74/250 [00:54<02:08,  1.37it/s]

{'epoch': 74, 'valid_accuracy': 0.7233668929992976, 'best_valid_epoch': 68, 'best_valid_score': 0.7347225474127839, 'test_accuracy': 0.7292795504716034, 'selected_test_score': 0.735300020068232}


 30%|███       | 75/250 [00:55<02:07,  1.37it/s]

{'epoch': 75, 'valid_accuracy': 0.7363615078435963, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7372566726871362, 'selected_test_score': 0.7372566726871362}


 30%|███       | 76/250 [00:55<02:07,  1.37it/s]

{'epoch': 76, 'valid_accuracy': 0.7341372044017794, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7348484848484849, 'selected_test_score': 0.7372566726871362}


 31%|███       | 77/250 [00:56<02:06,  1.37it/s]

{'epoch': 77, 'valid_accuracy': 0.7353078904237883, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7377082079068834, 'selected_test_score': 0.7372566726871362}


 31%|███       | 78/250 [00:57<02:09,  1.33it/s]

{'epoch': 78, 'valid_accuracy': 0.7280496370873332, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7317379088902267, 'selected_test_score': 0.7372566726871362}


 32%|███▏      | 79/250 [00:58<02:07,  1.34it/s]

{'epoch': 79, 'valid_accuracy': 0.7319129009599625, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7339454144089905, 'selected_test_score': 0.7372566726871362}


 32%|███▏      | 80/250 [00:58<02:05,  1.35it/s]

{'epoch': 80, 'valid_accuracy': 0.7316787637555607, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7332931968693558, 'selected_test_score': 0.7372566726871362}


 32%|███▏      | 81/250 [00:59<02:04,  1.36it/s]

{'epoch': 81, 'valid_accuracy': 0.7336689299929758, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7354505318081477, 'selected_test_score': 0.7372566726871362}


 33%|███▎      | 82/250 [01:00<02:03,  1.36it/s]

{'epoch': 82, 'valid_accuracy': 0.7260594708499181, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7302327914910697, 'selected_test_score': 0.7372566726871362}


 33%|███▎      | 83/250 [01:01<02:02,  1.36it/s]

{'epoch': 83, 'valid_accuracy': 0.7231327557948958, 'best_valid_epoch': 75, 'best_valid_score': 0.7363615078435963, 'test_accuracy': 0.7277242624924745, 'selected_test_score': 0.7372566726871362}


 34%|███▎      | 84/250 [01:01<02:01,  1.37it/s]

{'epoch': 84, 'valid_accuracy': 0.7381175368766097, 'best_valid_epoch': 84, 'best_valid_score': 0.7381175368766097, 'test_accuracy': 0.7376580373269115, 'selected_test_score': 0.7376580373269115}


 34%|███▍      | 85/250 [01:02<02:00,  1.37it/s]

{'epoch': 85, 'valid_accuracy': 0.7334347927885742, 'best_valid_epoch': 84, 'best_valid_score': 0.7381175368766097, 'test_accuracy': 0.7354003612281758, 'selected_test_score': 0.7376580373269115}


 34%|███▍      | 86/250 [01:03<01:59,  1.37it/s]

{'epoch': 86, 'valid_accuracy': 0.7391711542964177, 'best_valid_epoch': 86, 'best_valid_score': 0.7391711542964177, 'test_accuracy': 0.7389123018262092, 'selected_test_score': 0.7389123018262092}


 35%|███▍      | 87/250 [01:03<01:59,  1.37it/s]

{'epoch': 87, 'valid_accuracy': 0.7369468508546008, 'best_valid_epoch': 86, 'best_valid_score': 0.7391711542964177, 'test_accuracy': 0.7380092313867148, 'selected_test_score': 0.7389123018262092}


 35%|███▌      | 88/250 [01:04<01:58,  1.37it/s]

{'epoch': 88, 'valid_accuracy': 0.7372980566612035, 'best_valid_epoch': 86, 'best_valid_score': 0.7391711542964177, 'test_accuracy': 0.7385109371864339, 'selected_test_score': 0.7389123018262092}


 36%|███▌      | 89/250 [01:05<01:57,  1.37it/s]

{'epoch': 89, 'valid_accuracy': 0.7182158745024585, 'best_valid_epoch': 86, 'best_valid_score': 0.7391711542964177, 'test_accuracy': 0.7207505518763797, 'selected_test_score': 0.7389123018262092}


 36%|███▌      | 90/250 [01:06<02:00,  1.33it/s]

{'epoch': 90, 'valid_accuracy': 0.7397564973074221, 'best_valid_epoch': 90, 'best_valid_score': 0.7397564973074221, 'test_accuracy': 0.7400662251655629, 'selected_test_score': 0.7400662251655629}


 36%|███▋      | 91/250 [01:06<01:58,  1.34it/s]

{'epoch': 91, 'valid_accuracy': 0.7384687426832124, 'best_valid_epoch': 90, 'best_valid_score': 0.7397564973074221, 'test_accuracy': 0.7392133253060406, 'selected_test_score': 0.7400662251655629}


 37%|███▋      | 92/250 [01:07<01:56,  1.35it/s]

{'epoch': 92, 'valid_accuracy': 0.7257082650433154, 'best_valid_epoch': 90, 'best_valid_score': 0.7397564973074221, 'test_accuracy': 0.727975115392334, 'selected_test_score': 0.7400662251655629}


 37%|███▋      | 93/250 [01:08<01:55,  1.36it/s]

{'epoch': 93, 'valid_accuracy': 0.7392882228986186, 'best_valid_epoch': 90, 'best_valid_score': 0.7397564973074221, 'test_accuracy': 0.7413204896648605, 'selected_test_score': 0.7400662251655629}


 38%|███▊      | 94/250 [01:09<01:54,  1.36it/s]

{'epoch': 94, 'valid_accuracy': 0.7375321938656052, 'best_valid_epoch': 90, 'best_valid_score': 0.7397564973074221, 'test_accuracy': 0.7374573550070239, 'selected_test_score': 0.7400662251655629}


 38%|███▊      | 95/250 [01:09<01:53,  1.36it/s]

{'epoch': 95, 'valid_accuracy': 0.7402247717162257, 'best_valid_epoch': 95, 'best_valid_score': 0.7402247717162257, 'test_accuracy': 0.7413204896648605, 'selected_test_score': 0.7413204896648605}


 38%|███▊      | 96/250 [01:10<01:52,  1.37it/s]

{'epoch': 96, 'valid_accuracy': 0.7401077031140249, 'best_valid_epoch': 95, 'best_valid_score': 0.7402247717162257, 'test_accuracy': 0.7416215131446919, 'selected_test_score': 0.7413204896648605}


 39%|███▉      | 97/250 [01:11<01:51,  1.37it/s]

{'epoch': 97, 'valid_accuracy': 0.7402247717162257, 'best_valid_epoch': 95, 'best_valid_score': 0.7402247717162257, 'test_accuracy': 0.7419727072044953, 'selected_test_score': 0.7413204896648605}


 39%|███▉      | 98/250 [01:12<01:51,  1.37it/s]

{'epoch': 98, 'valid_accuracy': 0.7408101147272301, 'best_valid_epoch': 98, 'best_valid_score': 0.7408101147272301, 'test_accuracy': 0.7423740718442705, 'selected_test_score': 0.7423740718442705}


 40%|███▉      | 99/250 [01:12<01:50,  1.37it/s]

{'epoch': 99, 'valid_accuracy': 0.7376492624678062, 'best_valid_epoch': 98, 'best_valid_score': 0.7408101147272301, 'test_accuracy': 0.7388621312462372, 'selected_test_score': 0.7423740718442705}


 40%|████      | 100/250 [01:13<01:49,  1.37it/s]

{'epoch': 100, 'valid_accuracy': 0.7354249590259893, 'best_valid_epoch': 98, 'best_valid_score': 0.7408101147272301, 'test_accuracy': 0.7375576961669676, 'selected_test_score': 0.7423740718442705}


 40%|████      | 101/250 [01:14<01:48,  1.37it/s]

{'epoch': 101, 'valid_accuracy': 0.7402247717162257, 'best_valid_epoch': 98, 'best_valid_score': 0.7408101147272301, 'test_accuracy': 0.7418723660445514, 'selected_test_score': 0.7423740718442705}


 41%|████      | 102/250 [01:15<01:51,  1.33it/s]

{'epoch': 102, 'valid_accuracy': 0.7426832123624444, 'best_valid_epoch': 102, 'best_valid_score': 0.7426832123624444, 'test_accuracy': 0.7422737306843267, 'selected_test_score': 0.7422737306843267}


 41%|████      | 103/250 [01:15<01:49,  1.34it/s]

{'epoch': 103, 'valid_accuracy': 0.7403418403184266, 'best_valid_epoch': 102, 'best_valid_score': 0.7426832123624444, 'test_accuracy': 0.7420730483644391, 'selected_test_score': 0.7422737306843267}


 42%|████▏     | 104/250 [01:16<01:47,  1.35it/s]

{'epoch': 104, 'valid_accuracy': 0.7426832123624444, 'best_valid_epoch': 102, 'best_valid_score': 0.7426832123624444, 'test_accuracy': 0.7409692956050572, 'selected_test_score': 0.7422737306843267}


 42%|████▏     | 105/250 [01:17<01:46,  1.36it/s]

{'epoch': 105, 'valid_accuracy': 0.7425661437602435, 'best_valid_epoch': 102, 'best_valid_score': 0.7426832123624444, 'test_accuracy': 0.7408187838651414, 'selected_test_score': 0.7422737306843267}


 42%|████▏     | 106/250 [01:17<01:45,  1.36it/s]

{'epoch': 106, 'valid_accuracy': 0.7404589089206275, 'best_valid_epoch': 102, 'best_valid_score': 0.7426832123624444, 'test_accuracy': 0.7424242424242424, 'selected_test_score': 0.7422737306843267}


 43%|████▎     | 107/250 [01:18<01:44,  1.36it/s]

{'epoch': 107, 'valid_accuracy': 0.7445563099976586, 'best_valid_epoch': 107, 'best_valid_score': 0.7445563099976586, 'test_accuracy': 0.7434778246036524, 'selected_test_score': 0.7434778246036524}


 43%|████▎     | 108/250 [01:19<01:43,  1.37it/s]

{'epoch': 108, 'valid_accuracy': 0.7426832123624444, 'best_valid_epoch': 107, 'best_valid_score': 0.7445563099976586, 'test_accuracy': 0.7441802127232591, 'selected_test_score': 0.7434778246036524}


 44%|████▎     | 109/250 [01:20<01:43,  1.37it/s]

{'epoch': 109, 'valid_accuracy': 0.7423320065558418, 'best_valid_epoch': 107, 'best_valid_score': 0.7445563099976586, 'test_accuracy': 0.7439293598233996, 'selected_test_score': 0.7434778246036524}


 44%|████▍     | 110/250 [01:20<01:42,  1.37it/s]

{'epoch': 110, 'valid_accuracy': 0.7387028798876142, 'best_valid_epoch': 107, 'best_valid_score': 0.7445563099976586, 'test_accuracy': 0.7409191250250853, 'selected_test_score': 0.7434778246036524}


 44%|████▍     | 111/250 [01:21<01:41,  1.37it/s]

{'epoch': 111, 'valid_accuracy': 0.7423320065558418, 'best_valid_epoch': 107, 'best_valid_score': 0.7445563099976586, 'test_accuracy': 0.7447320891029501, 'selected_test_score': 0.7434778246036524}


 45%|████▍     | 112/250 [01:22<01:40,  1.37it/s]

{'epoch': 112, 'valid_accuracy': 0.7370639194568017, 'best_valid_epoch': 107, 'best_valid_score': 0.7445563099976586, 'test_accuracy': 0.7376580373269115, 'selected_test_score': 0.7434778246036524}


 45%|████▌     | 113/250 [01:23<01:40,  1.37it/s]

{'epoch': 113, 'valid_accuracy': 0.7449075158042613, 'best_valid_epoch': 113, 'best_valid_score': 0.7449075158042613, 'test_accuracy': 0.7436283363435682, 'selected_test_score': 0.7436283363435682}


 46%|████▌     | 114/250 [01:23<01:42,  1.33it/s]

{'epoch': 114, 'valid_accuracy': 0.7389370170920159, 'best_valid_epoch': 113, 'best_valid_score': 0.7449075158042613, 'test_accuracy': 0.7404174192253662, 'selected_test_score': 0.7436283363435682}


 46%|████▌     | 115/250 [01:24<01:40,  1.34it/s]

{'epoch': 115, 'valid_accuracy': 0.743034418169047, 'best_valid_epoch': 113, 'best_valid_score': 0.7449075158042613, 'test_accuracy': 0.7390628135661248, 'selected_test_score': 0.7436283363435682}


 46%|████▋     | 116/250 [01:25<01:39,  1.35it/s]

{'epoch': 116, 'valid_accuracy': 0.7409271833294311, 'best_valid_epoch': 113, 'best_valid_score': 0.7449075158042613, 'test_accuracy': 0.7433774834437086, 'selected_test_score': 0.7436283363435682}


 47%|████▋     | 117/250 [01:26<01:38,  1.36it/s]

{'epoch': 117, 'valid_accuracy': 0.7432685553734488, 'best_valid_epoch': 113, 'best_valid_score': 0.7449075158042613, 'test_accuracy': 0.7454344772225567, 'selected_test_score': 0.7436283363435682}


 47%|████▋     | 118/250 [01:26<01:37,  1.36it/s]

{'epoch': 118, 'valid_accuracy': 0.7460782018262702, 'best_valid_epoch': 118, 'best_valid_score': 0.7460782018262702, 'test_accuracy': 0.7445815773630343, 'selected_test_score': 0.7445815773630343}


 48%|████▊     | 119/250 [01:27<01:36,  1.36it/s]

{'epoch': 119, 'valid_accuracy': 0.7439709669866542, 'best_valid_epoch': 118, 'best_valid_score': 0.7460782018262702, 'test_accuracy': 0.745835841862332, 'selected_test_score': 0.7445815773630343}


 48%|████▊     | 120/250 [01:28<01:35,  1.36it/s]

{'epoch': 120, 'valid_accuracy': 0.746312339030672, 'best_valid_epoch': 120, 'best_valid_score': 0.746312339030672, 'test_accuracy': 0.7435781657635963, 'selected_test_score': 0.7435781657635963}


 48%|████▊     | 121/250 [01:28<01:34,  1.37it/s]

{'epoch': 121, 'valid_accuracy': 0.7432685553734488, 'best_valid_epoch': 120, 'best_valid_score': 0.746312339030672, 'test_accuracy': 0.7453341360626129, 'selected_test_score': 0.7435781657635963}


 49%|████▉     | 122/250 [01:29<01:33,  1.37it/s]

{'epoch': 122, 'valid_accuracy': 0.7453757902130649, 'best_valid_epoch': 120, 'best_valid_score': 0.746312339030672, 'test_accuracy': 0.7451334537427252, 'selected_test_score': 0.7435781657635963}


 49%|████▉     | 123/250 [01:30<01:32,  1.37it/s]

{'epoch': 123, 'valid_accuracy': 0.7426832123624444, 'best_valid_epoch': 120, 'best_valid_score': 0.746312339030672, 'test_accuracy': 0.7450331125827815, 'selected_test_score': 0.7435781657635963}


 50%|████▉     | 124/250 [01:31<01:32,  1.37it/s]

{'epoch': 124, 'valid_accuracy': 0.7464294076328729, 'best_valid_epoch': 124, 'best_valid_score': 0.7464294076328729, 'test_accuracy': 0.7454344772225567, 'selected_test_score': 0.7454344772225567}


 50%|█████     | 125/250 [01:31<01:31,  1.37it/s]

{'epoch': 125, 'valid_accuracy': 0.74736595645048, 'best_valid_epoch': 125, 'best_valid_score': 0.74736595645048, 'test_accuracy': 0.7450832831627534, 'selected_test_score': 0.7450832831627534}


 50%|█████     | 126/250 [01:32<01:33,  1.33it/s]

{'epoch': 126, 'valid_accuracy': 0.7454928588152657, 'best_valid_epoch': 125, 'best_valid_score': 0.74736595645048, 'test_accuracy': 0.746889424041742, 'selected_test_score': 0.7450832831627534}


 51%|█████     | 127/250 [01:33<01:31,  1.34it/s]

{'epoch': 127, 'valid_accuracy': 0.7401077031140249, 'best_valid_epoch': 125, 'best_valid_score': 0.74736595645048, 'test_accuracy': 0.7414208308248044, 'selected_test_score': 0.7450832831627534}


 51%|█████     | 128/250 [01:34<01:30,  1.35it/s]

{'epoch': 128, 'valid_accuracy': 0.7457269960196675, 'best_valid_epoch': 125, 'best_valid_score': 0.74736595645048, 'test_accuracy': 0.7475918121613486, 'selected_test_score': 0.7450832831627534}


 52%|█████▏    | 129/250 [01:34<01:29,  1.36it/s]

{'epoch': 129, 'valid_accuracy': 0.7474830250526808, 'best_valid_epoch': 129, 'best_valid_score': 0.7474830250526808, 'test_accuracy': 0.7458860124423038, 'selected_test_score': 0.7458860124423038}


 52%|█████▏    | 130/250 [01:35<01:28,  1.36it/s]

{'epoch': 130, 'valid_accuracy': 0.7466635448372746, 'best_valid_epoch': 129, 'best_valid_score': 0.7474830250526808, 'test_accuracy': 0.747943006221152, 'selected_test_score': 0.7458860124423038}


 52%|█████▏    | 131/250 [01:36<01:27,  1.36it/s]

{'epoch': 131, 'valid_accuracy': 0.7478342308592836, 'best_valid_epoch': 131, 'best_valid_score': 0.7478342308592836, 'test_accuracy': 0.746287377082079, 'selected_test_score': 0.746287377082079}


 53%|█████▎    | 132/250 [01:37<01:26,  1.36it/s]

{'epoch': 132, 'valid_accuracy': 0.746312339030672, 'best_valid_epoch': 131, 'best_valid_score': 0.7478342308592836, 'test_accuracy': 0.747943006221152, 'selected_test_score': 0.746287377082079}


 53%|█████▎    | 133/250 [01:37<01:25,  1.37it/s]

{'epoch': 133, 'valid_accuracy': 0.7474830250526808, 'best_valid_epoch': 131, 'best_valid_score': 0.7478342308592836, 'test_accuracy': 0.748394541440899, 'selected_test_score': 0.746287377082079}


 54%|█████▎    | 134/250 [01:38<01:24,  1.37it/s]

{'epoch': 134, 'valid_accuracy': 0.7467806134394755, 'best_valid_epoch': 131, 'best_valid_score': 0.7478342308592836, 'test_accuracy': 0.7480433473810957, 'selected_test_score': 0.746287377082079}


 54%|█████▍    | 135/250 [01:39<01:24,  1.37it/s]

{'epoch': 135, 'valid_accuracy': 0.7487707796768907, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.746337547662051, 'selected_test_score': 0.746337547662051}


 54%|█████▍    | 136/250 [01:40<01:23,  1.37it/s]

{'epoch': 136, 'valid_accuracy': 0.7483025052680871, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.7441802127232591, 'selected_test_score': 0.746337547662051}


 55%|█████▍    | 137/250 [01:40<01:22,  1.37it/s]

{'epoch': 137, 'valid_accuracy': 0.7478342308592836, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.7492976118803933, 'selected_test_score': 0.746337547662051}


 55%|█████▌    | 138/250 [01:41<01:24,  1.33it/s]

{'epoch': 138, 'valid_accuracy': 0.7474830250526808, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.7498996588400562, 'selected_test_score': 0.746337547662051}


 56%|█████▌    | 139/250 [01:42<01:22,  1.34it/s]

{'epoch': 139, 'valid_accuracy': 0.7478342308592836, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.7497993176801124, 'selected_test_score': 0.746337547662051}


 56%|█████▌    | 140/250 [01:43<01:21,  1.35it/s]

{'epoch': 140, 'valid_accuracy': 0.7486537110746898, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.7499498294200281, 'selected_test_score': 0.746337547662051}


 56%|█████▋    | 141/250 [01:43<01:20,  1.36it/s]

{'epoch': 141, 'valid_accuracy': 0.7485366424724889, 'best_valid_epoch': 135, 'best_valid_score': 0.7487707796768907, 'test_accuracy': 0.7506522175396347, 'selected_test_score': 0.746337547662051}


 57%|█████▋    | 142/250 [01:44<01:19,  1.36it/s]

{'epoch': 142, 'valid_accuracy': 0.7493561226878951, 'best_valid_epoch': 142, 'best_valid_score': 0.7493561226878951, 'test_accuracy': 0.7470901063616295, 'selected_test_score': 0.7470901063616295}


 57%|█████▋    | 143/250 [01:45<01:18,  1.36it/s]

{'epoch': 143, 'valid_accuracy': 0.748419573870288, 'best_valid_epoch': 142, 'best_valid_score': 0.7493561226878951, 'test_accuracy': 0.7503511940598033, 'selected_test_score': 0.7470901063616295}


 58%|█████▊    | 144/250 [01:45<01:17,  1.36it/s]

{'epoch': 144, 'valid_accuracy': 0.7452587216108639, 'best_valid_epoch': 142, 'best_valid_score': 0.7493561226878951, 'test_accuracy': 0.7480935179610676, 'selected_test_score': 0.7470901063616295}


 58%|█████▊    | 145/250 [01:46<01:16,  1.37it/s]

{'epoch': 145, 'valid_accuracy': 0.7507609459143058, 'best_valid_epoch': 145, 'best_valid_score': 0.7507609459143058, 'test_accuracy': 0.7496488059401967, 'selected_test_score': 0.7496488059401967}


 58%|█████▊    | 146/250 [01:47<01:16,  1.37it/s]

{'epoch': 146, 'valid_accuracy': 0.7507609459143058, 'best_valid_epoch': 145, 'best_valid_score': 0.7507609459143058, 'test_accuracy': 0.7485450531808148, 'selected_test_score': 0.7496488059401967}


 59%|█████▉    | 147/250 [01:48<01:15,  1.37it/s]

{'epoch': 147, 'valid_accuracy': 0.7493561226878951, 'best_valid_epoch': 145, 'best_valid_score': 0.7507609459143058, 'test_accuracy': 0.7517057997190447, 'selected_test_score': 0.7496488059401967}


 59%|█████▉    | 148/250 [01:48<01:14,  1.37it/s]

{'epoch': 148, 'valid_accuracy': 0.7507609459143058, 'best_valid_epoch': 145, 'best_valid_score': 0.7507609459143058, 'test_accuracy': 0.7509030704394943, 'selected_test_score': 0.7496488059401967}


 60%|█████▉    | 149/250 [01:49<01:13,  1.37it/s]

{'epoch': 149, 'valid_accuracy': 0.7516974947319129, 'best_valid_epoch': 149, 'best_valid_score': 0.7516974947319129, 'test_accuracy': 0.7513546056592414, 'selected_test_score': 0.7513546056592414}


 60%|██████    | 150/250 [01:50<01:15,  1.33it/s]

{'epoch': 150, 'valid_accuracy': 0.7516974947319129, 'best_valid_epoch': 149, 'best_valid_score': 0.7516974947319129, 'test_accuracy': 0.7496488059401967, 'selected_test_score': 0.7513546056592414}


 60%|██████    | 151/250 [01:51<01:13,  1.34it/s]

{'epoch': 151, 'valid_accuracy': 0.7511121517209085, 'best_valid_epoch': 149, 'best_valid_score': 0.7516974947319129, 'test_accuracy': 0.75210716435882, 'selected_test_score': 0.7513546056592414}


 61%|██████    | 152/250 [01:51<01:12,  1.35it/s]

{'epoch': 152, 'valid_accuracy': 0.7500585343011005, 'best_valid_epoch': 149, 'best_valid_score': 0.7516974947319129, 'test_accuracy': 0.7540638169777243, 'selected_test_score': 0.7513546056592414}


 61%|██████    | 153/250 [01:52<01:11,  1.36it/s]

{'epoch': 153, 'valid_accuracy': 0.7490049168812924, 'best_valid_epoch': 149, 'best_valid_score': 0.7516974947319129, 'test_accuracy': 0.7443307244631748, 'selected_test_score': 0.7513546056592414}


 62%|██████▏   | 154/250 [01:53<01:10,  1.36it/s]

{'epoch': 154, 'valid_accuracy': 0.7528681807539218, 'best_valid_epoch': 154, 'best_valid_score': 0.7528681807539218, 'test_accuracy': 0.7518061408789886, 'selected_test_score': 0.7518061408789886}


 62%|██████▏   | 155/250 [01:54<01:09,  1.36it/s]

{'epoch': 155, 'valid_accuracy': 0.7534535237649262, 'best_valid_epoch': 155, 'best_valid_score': 0.7534535237649262, 'test_accuracy': 0.7515552879791291, 'selected_test_score': 0.7515552879791291}


 62%|██████▏   | 156/250 [01:54<01:08,  1.36it/s]

{'epoch': 156, 'valid_accuracy': 0.7492390540856942, 'best_valid_epoch': 155, 'best_valid_score': 0.7534535237649262, 'test_accuracy': 0.7467890828817981, 'selected_test_score': 0.7515552879791291}


 63%|██████▎   | 157/250 [01:55<01:08,  1.37it/s]

{'epoch': 157, 'valid_accuracy': 0.7509950831187076, 'best_valid_epoch': 155, 'best_valid_score': 0.7534535237649262, 'test_accuracy': 0.7556692755368252, 'selected_test_score': 0.7515552879791291}


 63%|██████▎   | 158/250 [01:56<01:07,  1.37it/s]

{'epoch': 158, 'valid_accuracy': 0.7518145633341138, 'best_valid_epoch': 155, 'best_valid_score': 0.7534535237649262, 'test_accuracy': 0.7506020469596628, 'selected_test_score': 0.7515552879791291}


 64%|██████▎   | 159/250 [01:56<01:06,  1.37it/s]

{'epoch': 159, 'valid_accuracy': 0.7541559353781316, 'best_valid_epoch': 159, 'best_valid_score': 0.7541559353781316, 'test_accuracy': 0.7543648404575557, 'selected_test_score': 0.7543648404575557}


 64%|██████▍   | 160/250 [01:57<01:05,  1.37it/s]

{'epoch': 160, 'valid_accuracy': 0.751580426129712, 'best_valid_epoch': 159, 'best_valid_score': 0.7541559353781316, 'test_accuracy': 0.7506020469596628, 'selected_test_score': 0.7543648404575557}


 64%|██████▍   | 161/250 [01:58<01:05,  1.37it/s]

{'epoch': 161, 'valid_accuracy': 0.7519316319363146, 'best_valid_epoch': 159, 'best_valid_score': 0.7541559353781316, 'test_accuracy': 0.75105358217941, 'selected_test_score': 0.7543648404575557}


 65%|██████▍   | 162/250 [01:59<01:06,  1.33it/s]

{'epoch': 162, 'valid_accuracy': 0.7546242097869351, 'best_valid_epoch': 162, 'best_valid_score': 0.7546242097869351, 'test_accuracy': 0.7522075055187638, 'selected_test_score': 0.7522075055187638}


 65%|██████▌   | 163/250 [01:59<01:04,  1.34it/s]

{'epoch': 163, 'valid_accuracy': 0.7562631702177476, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7550672285771624, 'selected_test_score': 0.7550672285771624}


 66%|██████▌   | 164/250 [02:00<01:03,  1.35it/s]

{'epoch': 164, 'valid_accuracy': 0.7539217981737298, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7575757575757576, 'selected_test_score': 0.7550672285771624}


 66%|██████▌   | 165/250 [02:01<01:02,  1.35it/s]

{'epoch': 165, 'valid_accuracy': 0.7514633575275111, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7492474413004214, 'selected_test_score': 0.7550672285771624}


 66%|██████▋   | 166/250 [02:02<01:01,  1.36it/s]

{'epoch': 166, 'valid_accuracy': 0.755794895808944, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7561208107565723, 'selected_test_score': 0.7550672285771624}


 67%|██████▋   | 167/250 [02:02<01:00,  1.36it/s]

{'epoch': 167, 'valid_accuracy': 0.7540388667759307, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7605859923740719, 'selected_test_score': 0.7550672285771624}


 67%|██████▋   | 168/250 [02:03<01:00,  1.36it/s]

{'epoch': 168, 'valid_accuracy': 0.7552095527979396, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7528095524784266, 'selected_test_score': 0.7550672285771624}


 68%|██████▊   | 169/250 [02:04<00:59,  1.36it/s]

{'epoch': 169, 'valid_accuracy': 0.7546242097869351, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7596829219345775, 'selected_test_score': 0.7550672285771624}


 68%|██████▊   | 170/250 [02:05<00:58,  1.37it/s]

{'epoch': 170, 'valid_accuracy': 0.7556778272067431, 'best_valid_epoch': 163, 'best_valid_score': 0.7562631702177476, 'test_accuracy': 0.7581778045354204, 'selected_test_score': 0.7550672285771624}


 68%|██████▊   | 171/250 [02:05<00:57,  1.37it/s]

{'epoch': 171, 'valid_accuracy': 0.7566143760243502, 'best_valid_epoch': 171, 'best_valid_score': 0.7566143760243502, 'test_accuracy': 0.7577262693156733, 'selected_test_score': 0.7577262693156733}


 69%|██████▉   | 172/250 [02:06<00:57,  1.37it/s]

{'epoch': 172, 'valid_accuracy': 0.7589557480683681, 'best_valid_epoch': 172, 'best_valid_score': 0.7589557480683681, 'test_accuracy': 0.7551675697371062, 'selected_test_score': 0.7551675697371062}


 69%|██████▉   | 173/250 [02:07<00:56,  1.37it/s]

{'epoch': 173, 'valid_accuracy': 0.7525169749473192, 'best_valid_epoch': 172, 'best_valid_score': 0.7589557480683681, 'test_accuracy': 0.7509532410194661, 'selected_test_score': 0.7551675697371062}


 70%|██████▉   | 174/250 [02:08<00:57,  1.33it/s]

{'epoch': 174, 'valid_accuracy': 0.7590728166705689, 'best_valid_epoch': 174, 'best_valid_score': 0.7590728166705689, 'test_accuracy': 0.7553180814770218, 'selected_test_score': 0.7553180814770218}


 70%|███████   | 175/250 [02:08<00:55,  1.34it/s]

{'epoch': 175, 'valid_accuracy': 0.7509950831187076, 'best_valid_epoch': 174, 'best_valid_score': 0.7590728166705689, 'test_accuracy': 0.7490969295605057, 'selected_test_score': 0.7553180814770218}


 70%|███████   | 176/250 [02:09<00:54,  1.35it/s]

{'epoch': 176, 'valid_accuracy': 0.7581362678529618, 'best_valid_epoch': 174, 'best_valid_score': 0.7590728166705689, 'test_accuracy': 0.7585791691751956, 'selected_test_score': 0.7553180814770218}


 71%|███████   | 177/250 [02:10<00:53,  1.36it/s]

{'epoch': 177, 'valid_accuracy': 0.7587216108639663, 'best_valid_epoch': 174, 'best_valid_score': 0.7590728166705689, 'test_accuracy': 0.7556692755368252, 'selected_test_score': 0.7553180814770218}


 71%|███████   | 178/250 [02:10<00:52,  1.36it/s]

{'epoch': 178, 'valid_accuracy': 0.7577850620463592, 'best_valid_epoch': 174, 'best_valid_score': 0.7590728166705689, 'test_accuracy': 0.7637467389123018, 'selected_test_score': 0.7553180814770218}


 72%|███████▏  | 179/250 [02:11<00:52,  1.36it/s]

{'epoch': 179, 'valid_accuracy': 0.7595410910793725, 'best_valid_epoch': 179, 'best_valid_score': 0.7595410910793725, 'test_accuracy': 0.7559702990166566, 'selected_test_score': 0.7559702990166566}


 72%|███████▏  | 180/250 [02:12<00:51,  1.36it/s]

{'epoch': 180, 'valid_accuracy': 0.7570826504331538, 'best_valid_epoch': 179, 'best_valid_score': 0.7595410910793725, 'test_accuracy': 0.7621412803532008, 'selected_test_score': 0.7559702990166566}


 72%|███████▏  | 181/250 [02:13<00:50,  1.37it/s]

{'epoch': 181, 'valid_accuracy': 0.7570826504331538, 'best_valid_epoch': 179, 'best_valid_score': 0.7595410910793725, 'test_accuracy': 0.7549167168372466, 'selected_test_score': 0.7559702990166566}


 73%|███████▎  | 182/250 [02:13<00:49,  1.37it/s]

{'epoch': 182, 'valid_accuracy': 0.7582533364551627, 'best_valid_epoch': 179, 'best_valid_score': 0.7595410910793725, 'test_accuracy': 0.7556692755368252, 'selected_test_score': 0.7559702990166566}


 73%|███████▎  | 183/250 [02:14<00:48,  1.37it/s]

{'epoch': 183, 'valid_accuracy': 0.7596581596815734, 'best_valid_epoch': 183, 'best_valid_score': 0.7596581596815734, 'test_accuracy': 0.7640979329721052, 'selected_test_score': 0.7640979329721052}


 74%|███████▎  | 184/250 [02:15<00:48,  1.37it/s]

{'epoch': 184, 'valid_accuracy': 0.7589557480683681, 'best_valid_epoch': 183, 'best_valid_score': 0.7596581596815734, 'test_accuracy': 0.7564720048163757, 'selected_test_score': 0.7640979329721052}


 74%|███████▍  | 185/250 [02:16<00:47,  1.37it/s]

{'epoch': 185, 'valid_accuracy': 0.7605947084991805, 'best_valid_epoch': 185, 'best_valid_score': 0.7605947084991805, 'test_accuracy': 0.7654525386313465, 'selected_test_score': 0.7654525386313465}


 74%|███████▍  | 186/250 [02:16<00:48,  1.33it/s]

{'epoch': 186, 'valid_accuracy': 0.7574338562397565, 'best_valid_epoch': 185, 'best_valid_score': 0.7605947084991805, 'test_accuracy': 0.7558197872767409, 'selected_test_score': 0.7654525386313465}


 75%|███████▍  | 187/250 [02:17<00:46,  1.34it/s]

{'epoch': 187, 'valid_accuracy': 0.7614141887145868, 'best_valid_epoch': 187, 'best_valid_score': 0.7614141887145868, 'test_accuracy': 0.7658037326911499, 'selected_test_score': 0.7658037326911499}


 75%|███████▌  | 188/250 [02:18<00:45,  1.35it/s]

{'epoch': 188, 'valid_accuracy': 0.7602435026925779, 'best_valid_epoch': 187, 'best_valid_score': 0.7614141887145868, 'test_accuracy': 0.758880192655027, 'selected_test_score': 0.7658037326911499}


 76%|███████▌  | 189/250 [02:19<00:45,  1.36it/s]

{'epoch': 189, 'valid_accuracy': 0.7598922968859751, 'best_valid_epoch': 187, 'best_valid_score': 0.7614141887145868, 'test_accuracy': 0.75737507525587, 'selected_test_score': 0.7658037326911499}


 76%|███████▌  | 190/250 [02:19<00:44,  1.36it/s]

{'epoch': 190, 'valid_accuracy': 0.7598922968859751, 'best_valid_epoch': 187, 'best_valid_score': 0.7614141887145868, 'test_accuracy': 0.7611880393337347, 'selected_test_score': 0.7658037326911499}


 76%|███████▋  | 191/250 [02:20<00:43,  1.36it/s]

{'epoch': 191, 'valid_accuracy': 0.7631702177476001, 'best_valid_epoch': 191, 'best_valid_score': 0.7631702177476001, 'test_accuracy': 0.7676098735701384, 'selected_test_score': 0.7676098735701384}


 77%|███████▋  | 192/250 [02:21<00:42,  1.36it/s]

{'epoch': 192, 'valid_accuracy': 0.7488878482790915, 'best_valid_epoch': 191, 'best_valid_score': 0.7631702177476001, 'test_accuracy': 0.7527593818984547, 'selected_test_score': 0.7676098735701384}


 77%|███████▋  | 193/250 [02:22<00:41,  1.36it/s]

{'epoch': 193, 'valid_accuracy': 0.7641067665652073, 'best_valid_epoch': 193, 'best_valid_score': 0.7641067665652073, 'test_accuracy': 0.7687136263295203, 'selected_test_score': 0.7687136263295203}


 78%|███████▊  | 194/250 [02:22<00:41,  1.37it/s]

{'epoch': 194, 'valid_accuracy': 0.7631702177476001, 'best_valid_epoch': 193, 'best_valid_score': 0.7641067665652073, 'test_accuracy': 0.767308850090307, 'selected_test_score': 0.7687136263295203}


 78%|███████▊  | 195/250 [02:23<00:40,  1.37it/s]

{'epoch': 195, 'valid_accuracy': 0.7609459143057832, 'best_valid_epoch': 193, 'best_valid_score': 0.7641067665652073, 'test_accuracy': 0.7665562913907285, 'selected_test_score': 0.7687136263295203}


 78%|███████▊  | 196/250 [02:24<00:39,  1.37it/s]

{'epoch': 196, 'valid_accuracy': 0.7564973074221494, 'best_valid_epoch': 193, 'best_valid_score': 0.7641067665652073, 'test_accuracy': 0.7554184226369657, 'selected_test_score': 0.7687136263295203}


 79%|███████▉  | 197/250 [02:24<00:38,  1.37it/s]

{'epoch': 197, 'valid_accuracy': 0.761180051510185, 'best_valid_epoch': 193, 'best_valid_score': 0.7641067665652073, 'test_accuracy': 0.7624423038330324, 'selected_test_score': 0.7687136263295203}


 79%|███████▉  | 198/250 [02:25<00:39,  1.33it/s]

{'epoch': 198, 'valid_accuracy': 0.7612971201123858, 'best_valid_epoch': 193, 'best_valid_score': 0.7641067665652073, 'test_accuracy': 0.7601344571543247, 'selected_test_score': 0.7687136263295203}


 80%|███████▉  | 199/250 [02:26<00:38,  1.34it/s]

{'epoch': 199, 'valid_accuracy': 0.7642238351674081, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7690146498093517, 'selected_test_score': 0.7690146498093517}


 80%|████████  | 200/250 [02:27<00:37,  1.35it/s]

{'epoch': 200, 'valid_accuracy': 0.7538047295715289, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7517057997190447, 'selected_test_score': 0.7690146498093517}


 80%|████████  | 201/250 [02:27<00:36,  1.35it/s]

{'epoch': 201, 'valid_accuracy': 0.7602435026925779, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7621914509331728, 'selected_test_score': 0.7690146498093517}


 81%|████████  | 202/250 [02:28<00:35,  1.36it/s]

{'epoch': 202, 'valid_accuracy': 0.7575509248419574, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7550672285771624, 'selected_test_score': 0.7690146498093517}


 81%|████████  | 203/250 [02:29<00:34,  1.36it/s]

{'epoch': 203, 'valid_accuracy': 0.7552095527979396, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7525586995785671, 'selected_test_score': 0.7690146498093517}


 82%|████████▏ | 204/250 [02:30<00:33,  1.36it/s]

{'epoch': 204, 'valid_accuracy': 0.7600093654881761, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7624423038330324, 'selected_test_score': 0.7690146498093517}


 82%|████████▏ | 205/250 [02:30<00:32,  1.37it/s]

{'epoch': 205, 'valid_accuracy': 0.760126434090377, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7629941802127232, 'selected_test_score': 0.7690146498093517}


 82%|████████▏ | 206/250 [02:31<00:32,  1.37it/s]

{'epoch': 206, 'valid_accuracy': 0.7590728166705689, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7606361629540438, 'selected_test_score': 0.7690146498093517}


 83%|████████▎ | 207/250 [02:32<00:31,  1.37it/s]

{'epoch': 207, 'valid_accuracy': 0.7618824631233903, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7641481035520771, 'selected_test_score': 0.7690146498093517}


 83%|████████▎ | 208/250 [02:33<00:30,  1.37it/s]

{'epoch': 208, 'valid_accuracy': 0.758019199250761, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7561208107565723, 'selected_test_score': 0.7690146498093517}


 84%|████████▎ | 209/250 [02:33<00:29,  1.37it/s]

{'epoch': 209, 'valid_accuracy': 0.7584874736595645, 'best_valid_epoch': 199, 'best_valid_score': 0.7642238351674081, 'test_accuracy': 0.7596327513546056, 'selected_test_score': 0.7690146498093517}


 84%|████████▍ | 210/250 [02:34<00:30,  1.33it/s]

{'epoch': 210, 'valid_accuracy': 0.7658627955982206, 'best_valid_epoch': 210, 'best_valid_score': 0.7658627955982206, 'test_accuracy': 0.7668573148705599, 'selected_test_score': 0.7668573148705599}


 84%|████████▍ | 211/250 [02:35<00:29,  1.34it/s]

{'epoch': 211, 'valid_accuracy': 0.7682041676422383, 'best_valid_epoch': 211, 'best_valid_score': 0.7682041676422383, 'test_accuracy': 0.7723259080874975, 'selected_test_score': 0.7723259080874975}


 85%|████████▍ | 212/250 [02:36<00:28,  1.35it/s]

{'epoch': 212, 'valid_accuracy': 0.7679700304378366, 'best_valid_epoch': 211, 'best_valid_score': 0.7682041676422383, 'test_accuracy': 0.772075055187638, 'selected_test_score': 0.7723259080874975}


 85%|████████▌ | 213/250 [02:36<00:27,  1.36it/s]

{'epoch': 213, 'valid_accuracy': 0.7663310700070242, 'best_valid_epoch': 211, 'best_valid_score': 0.7682041676422383, 'test_accuracy': 0.7680112382099137, 'selected_test_score': 0.7723259080874975}


 86%|████████▌ | 214/250 [02:37<00:26,  1.36it/s]

{'epoch': 214, 'valid_accuracy': 0.7694919222664481, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7711719847481436, 'selected_test_score': 0.7711719847481436}


 86%|████████▌ | 215/250 [02:38<00:25,  1.36it/s]

{'epoch': 215, 'valid_accuracy': 0.7575509248419574, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7552177403170781, 'selected_test_score': 0.7711719847481436}


 86%|████████▋ | 216/250 [02:38<00:24,  1.36it/s]

{'epoch': 216, 'valid_accuracy': 0.7598922968859751, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7629941802127232, 'selected_test_score': 0.7711719847481436}


 87%|████████▋ | 217/250 [02:39<00:24,  1.37it/s]

{'epoch': 217, 'valid_accuracy': 0.7583704050573636, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7581276339554486, 'selected_test_score': 0.7711719847481436}


 87%|████████▋ | 218/250 [02:40<00:23,  1.37it/s]

{'epoch': 218, 'valid_accuracy': 0.7671505502224304, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7685631145896047, 'selected_test_score': 0.7711719847481436}


 88%|████████▊ | 219/250 [02:41<00:22,  1.37it/s]

{'epoch': 219, 'valid_accuracy': 0.7564973074221494, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7552679108970499, 'selected_test_score': 0.7711719847481436}


 88%|████████▊ | 220/250 [02:41<00:21,  1.37it/s]

{'epoch': 220, 'valid_accuracy': 0.7581362678529618, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7613887216536224, 'selected_test_score': 0.7711719847481436}


 88%|████████▊ | 221/250 [02:42<00:21,  1.37it/s]

{'epoch': 221, 'valid_accuracy': 0.7583704050573636, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7600842865743528, 'selected_test_score': 0.7711719847481436}


 89%|████████▉ | 222/250 [02:43<00:21,  1.33it/s]

{'epoch': 222, 'valid_accuracy': 0.7583704050573636, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7618402568733694, 'selected_test_score': 0.7711719847481436}


 89%|████████▉ | 223/250 [02:44<00:20,  1.34it/s]

{'epoch': 223, 'valid_accuracy': 0.7618824631233903, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7652016857314871, 'selected_test_score': 0.7711719847481436}


 90%|████████▉ | 224/250 [02:44<00:19,  1.35it/s]

{'epoch': 224, 'valid_accuracy': 0.7586045422617654, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7601846277342966, 'selected_test_score': 0.7711719847481436}


 90%|█████████ | 225/250 [02:45<00:18,  1.36it/s]

{'epoch': 225, 'valid_accuracy': 0.764340903769609, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7679610676299418, 'selected_test_score': 0.7711719847481436}


 90%|█████████ | 226/250 [02:46<00:17,  1.36it/s]

{'epoch': 226, 'valid_accuracy': 0.7470147506438773, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7470901063616295, 'selected_test_score': 0.7711719847481436}


 91%|█████████ | 227/250 [02:47<00:16,  1.36it/s]

{'epoch': 227, 'valid_accuracy': 0.7656286583938188, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7667569737106161, 'selected_test_score': 0.7711719847481436}


 91%|█████████ | 228/250 [02:47<00:16,  1.36it/s]

{'epoch': 228, 'valid_accuracy': 0.7538047295715289, 'best_valid_epoch': 214, 'best_valid_score': 0.7694919222664481, 'test_accuracy': 0.7532610876981738, 'selected_test_score': 0.7711719847481436}


 92%|█████████▏| 229/250 [02:48<00:15,  1.37it/s]

{'epoch': 229, 'valid_accuracy': 0.7705455396862562, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7738811960666265, 'selected_test_score': 0.7738811960666265}


 92%|█████████▏| 230/250 [02:49<00:14,  1.37it/s]

{'epoch': 230, 'valid_accuracy': 0.7595410910793725, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7622917920931166, 'selected_test_score': 0.7738811960666265}


 92%|█████████▏| 231/250 [02:49<00:13,  1.37it/s]

{'epoch': 231, 'valid_accuracy': 0.762116600327792, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.765251856311459, 'selected_test_score': 0.7738811960666265}


 93%|█████████▎| 232/250 [02:50<00:13,  1.37it/s]

{'epoch': 232, 'valid_accuracy': 0.7586045422617654, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7609873570138471, 'selected_test_score': 0.7738811960666265}


 93%|█████████▎| 233/250 [02:51<00:12,  1.37it/s]

{'epoch': 233, 'valid_accuracy': 0.7699601966752517, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7742323901264299, 'selected_test_score': 0.7738811960666265}


 94%|█████████▎| 234/250 [02:52<00:12,  1.33it/s]

{'epoch': 234, 'valid_accuracy': 0.7587216108639663, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7603351394742123, 'selected_test_score': 0.7738811960666265}


 94%|█████████▍| 235/250 [02:52<00:11,  1.34it/s]

{'epoch': 235, 'valid_accuracy': 0.7683212362444393, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7692655027092113, 'selected_test_score': 0.7738811960666265}


 94%|█████████▍| 236/250 [02:53<00:10,  1.35it/s]

{'epoch': 236, 'valid_accuracy': 0.7541559353781316, 'best_valid_epoch': 229, 'best_valid_score': 0.7705455396862562, 'test_accuracy': 0.7540638169777243, 'selected_test_score': 0.7738811960666265}


 95%|█████████▍| 237/250 [02:54<00:09,  1.35it/s]

{'epoch': 237, 'valid_accuracy': 0.7732381175368767, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7750852899859523, 'selected_test_score': 0.7750852899859523}


 95%|█████████▌| 238/250 [02:55<00:08,  1.36it/s]

{'epoch': 238, 'valid_accuracy': 0.7604776398969796, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7626931567328918, 'selected_test_score': 0.7750852899859523}


 96%|█████████▌| 239/250 [02:55<00:08,  1.36it/s]

{'epoch': 239, 'valid_accuracy': 0.7711308826972606, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7754364840457556, 'selected_test_score': 0.7750852899859523}


 96%|█████████▌| 240/250 [02:56<00:07,  1.36it/s]

{'epoch': 240, 'valid_accuracy': 0.7538047295715289, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7530102347983143, 'selected_test_score': 0.7750852899859523}


 96%|█████████▋| 241/250 [02:57<00:06,  1.36it/s]

{'epoch': 241, 'valid_accuracy': 0.7682041676422383, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7707706201083685, 'selected_test_score': 0.7750852899859523}


 97%|█████████▋| 242/250 [02:58<00:05,  1.37it/s]

{'epoch': 242, 'valid_accuracy': 0.7603605712947787, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7641481035520771, 'selected_test_score': 0.7750852899859523}


 97%|█████████▋| 243/250 [02:58<00:05,  1.37it/s]

{'epoch': 243, 'valid_accuracy': 0.7720674315148677, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.772024884607666, 'selected_test_score': 0.7750852899859523}


 98%|█████████▊| 244/250 [02:59<00:04,  1.37it/s]

{'epoch': 244, 'valid_accuracy': 0.7711308826972606, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7752859723058398, 'selected_test_score': 0.7750852899859523}


 98%|█████████▊| 245/250 [03:00<00:03,  1.37it/s]

{'epoch': 245, 'valid_accuracy': 0.770662608288457, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7761388721653623, 'selected_test_score': 0.7750852899859523}


 98%|█████████▊| 246/250 [03:01<00:03,  1.33it/s]

{'epoch': 246, 'valid_accuracy': 0.7619995317255912, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7659542444310656, 'selected_test_score': 0.7750852899859523}


 99%|█████████▉| 247/250 [03:01<00:02,  1.34it/s]

{'epoch': 247, 'valid_accuracy': 0.7724186373214704, 'best_valid_epoch': 237, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7752358017258679, 'selected_test_score': 0.7750852899859523}


 99%|█████████▉| 248/250 [03:02<00:01,  1.35it/s]

{'epoch': 248, 'valid_accuracy': 0.7733551861390775, 'best_valid_epoch': 248, 'best_valid_score': 0.7733551861390775, 'test_accuracy': 0.7729781256271322, 'selected_test_score': 0.7729781256271322}


100%|█████████▉| 249/250 [03:03<00:00,  1.35it/s]

{'epoch': 249, 'valid_accuracy': 0.7553266214001405, 'best_valid_epoch': 248, 'best_valid_score': 0.7733551861390775, 'test_accuracy': 0.7543648404575557, 'selected_test_score': 0.7729781256271322}


100%|██████████| 250/250 [03:04<00:00,  1.36it/s]

{'epoch': 250, 'valid_accuracy': 0.7577850620463592, 'best_valid_epoch': 248, 'best_valid_score': 0.7733551861390775, 'test_accuracy': 0.7603351394742123, 'selected_test_score': 0.7729781256271322}





In [75]:
model

DOFENClassifier(
  (condition_generation): ConditionGeneration(
    (phi_1): ModuleDict(
      (num): Sequential(
        (0): Reshape()
        (1): FastGroupConv1d(7, 448, kernel_size=(1,), stride=(1,), groups=7)
        (2): Sigmoid()
        (3): Reshape()
      )
    )
  )
  (proj): Linear(in_features=1, out_features=16, bias=True)
  (rodt_construction): rODTConstruction()
  (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (attn): MultiheadAttention(
    (qw_proj): Linear(in_features=16, out_features=16, bias=True)
    (kw_proj): Linear(in_features=16, out_features=16, bias=True)
    (vw_proj): Linear(in_features=16, out_features=16, bias=True)
    (ow_proj): Linear(in_features=16, out_features=1, bias=True)
    (qE_proj): Linear(in_features=16, out_features=16, bias=True)
    (kE_proj): Linear(in_features=16, out_features=16, bias=True)
    (vE_proj): Linear(in_features=16, out_features=16, bias=True)
    (oE_proj): Linear(in_features=16, out_features=16, bias=True)

In [63]:
model.attn

MultiheadAttention(
  (qw_proj): Linear(in_features=16, out_features=16, bias=True)
  (kw_proj): Linear(in_features=16, out_features=16, bias=True)
  (vw_proj): Linear(in_features=16, out_features=16, bias=True)
  (ow_proj): Linear(in_features=16, out_features=1, bias=True)
  (qE_proj): Linear(in_features=16, out_features=16, bias=True)
  (kE_proj): Linear(in_features=16, out_features=16, bias=True)
  (vE_proj): Linear(in_features=16, out_features=16, bias=True)
  (oE_proj): Linear(in_features=16, out_features=16, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [64]:
model.cuda()(torch.tensor(train_X[:256]).cuda())

{'pred': tensor([[ 0.2464,  0.0590],
         [-0.5380,  1.1692],
         [ 0.2164,  0.1967],
         [ 0.6682, -0.6180],
         [ 1.0577, -1.0969],
         [ 0.6896, -0.6419],
         [ 1.0186, -1.0187],
         [ 1.1412, -1.2141],
         [ 0.3392, -0.0855],
         [ 0.9673, -1.0927],
         [ 0.5967, -0.4012],
         [-0.3925,  1.0767],
         [ 1.0105, -1.0750],
         [-0.9391,  1.6181],
         [ 0.1245,  0.2889],
         [ 0.3975, -0.1869],
         [ 0.3330, -0.0569],
         [ 0.2085,  0.1799],
         [ 0.3003, -0.0284],
         [-0.8618,  1.4978],
         [-0.9471,  1.6148],
         [ 0.9647, -1.0814],
         [ 0.6192, -0.5106],
         [-0.8640,  1.5417],
         [ 0.4123, -0.0875],
         [ 0.9487, -1.0439],
         [ 0.2228,  0.1897],
         [-0.5909,  1.2093],
         [-0.1680,  0.7478],
         [ 0.7082, -0.6771],
         [ 0.8036, -0.7792],
         [ 0.0167,  0.4327],
         [ 1.0330, -1.1069],
         [ 1.2381, -1.3095],
      

In [65]:
activations = {}

def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

handle = model.get_submodule("attn.oE_proj").register_forward_hook(get_activation("oE_proj"))

In [66]:
model.cuda()(torch.tensor(train_X[:256]).cuda())

{'pred': tensor([[ 0.2464,  0.0590],
         [-0.5380,  1.1692],
         [ 0.2164,  0.1967],
         [ 0.6682, -0.6180],
         [ 1.0577, -1.0969],
         [ 0.6896, -0.6419],
         [ 1.0186, -1.0187],
         [ 1.1412, -1.2141],
         [ 0.3392, -0.0855],
         [ 0.9673, -1.0927],
         [ 0.5967, -0.4012],
         [-0.3925,  1.0767],
         [ 1.0105, -1.0750],
         [-0.9391,  1.6181],
         [ 0.1245,  0.2889],
         [ 0.3975, -0.1869],
         [ 0.3330, -0.0569],
         [ 0.2085,  0.1799],
         [ 0.3003, -0.0284],
         [-0.8618,  1.4978],
         [-0.9471,  1.6148],
         [ 0.9647, -1.0814],
         [ 0.6192, -0.5106],
         [-0.8640,  1.5417],
         [ 0.4123, -0.0875],
         [ 0.9487, -1.0439],
         [ 0.2228,  0.1897],
         [-0.5909,  1.2093],
         [-0.1680,  0.7478],
         [ 0.7082, -0.6771],
         [ 0.8036, -0.7792],
         [ 0.0167,  0.4327],
         [ 1.0330, -1.1069],
         [ 1.2381, -1.3095],
      

In [67]:
oE_proj_output = activations["oE_proj"]
print(oE_proj_output.shape)

torch.Size([28672, 4, 16])


In [68]:
oE_proj_output = oE_proj_output.reshape(256, -1, 4, 16)

In [69]:
oE_proj_output.shape

torch.Size([256, 112, 4, 16])

In [70]:
x = oE_proj_output[0].view(112, -1)

In [71]:
x_ = x / ((x ** 2).sum(dim=1) ** 0.5)[:, None]
sim = x_ @ x_.T
sim

tensor([[ 1.0000,  0.9992,  0.9958,  ..., -0.9790,  0.9919,  0.8513],
        [ 0.9992,  1.0000,  0.9965,  ..., -0.9790,  0.9906,  0.8464],
        [ 0.9958,  0.9965,  1.0000,  ..., -0.9633,  0.9842,  0.8406],
        ...,
        [-0.9790, -0.9790, -0.9633,  ...,  1.0000, -0.9733, -0.8030],
        [ 0.9919,  0.9906,  0.9842,  ..., -0.9733,  1.0000,  0.9014],
        [ 0.8513,  0.8464,  0.8406,  ..., -0.8030,  0.9014,  1.0000]],
       device='cuda:0')

In [72]:
sim[0]

tensor([ 1.0000,  0.9992,  0.9958, -0.6870,  0.9765,  0.9838, -0.9893,  0.9834,
        -0.9766, -0.8133, -0.9826,  0.8599,  0.9035, -0.9912, -0.9935,  0.9223,
        -0.9852,  0.9703,  0.9482,  0.8411,  0.3434,  0.9618,  0.9940, -0.3987,
         0.9868, -0.8418,  0.9356,  0.9157, -0.9411,  0.9697,  0.1964, -0.4948,
        -0.0139,  0.9717, -0.9834,  0.8717, -0.9824, -0.9822,  0.9630,  0.9145,
         0.9693,  0.9978,  0.6707,  0.9900, -0.9802, -0.9811,  0.7056, -0.9843,
        -0.9906, -0.9753, -0.9456,  0.9859, -0.9421, -0.4298,  0.9932,  0.8508,
        -0.9796, -0.2989, -0.9884,  0.4897,  0.9947,  0.9717,  0.9943, -0.6062,
        -0.9369,  0.9781, -0.9854, -0.5740,  0.9708,  0.9985, -0.8017,  0.9351,
         0.8686,  0.9973,  0.9785,  0.9805,  0.9970, -0.9857,  0.9885, -0.9731,
         0.6744, -0.8712,  0.9381,  0.9888,  0.9916, -0.9768,  0.9922, -0.9572,
         0.8278,  0.9864,  0.3636,  0.9506, -0.9924,  0.9902,  0.6082,  0.9664,
        -0.5141, -0.0721,  0.9563, -0.98