In [1]:
from pathlib import Path

from src.metric import cls_eval_funs, reg_eval_funs
from src.trainer import Trainer
from src.utils import load_pkl_data, set_random_seed

In [13]:
class Config:
    data_dir = '/home/jiawei/Desktop/github/DOFEN/tabular-benchmark/tabular_benchmark_data'
    data_id = '361060'
    n_epoch = 250
    batch_size = 256
    target_transform = False

In [85]:
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from src.models.base import BaseClassifier, BaseRegressor
from src.models.layers import ResidualLayer, ReGLU


class Reshape(nn.Module):
    def __init__(self, *args: int) -> None:
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(self.shape)


class FastGroupConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        self.fast_mode = kwargs.pop('fast_mode')
        nn.Conv1d.__init__(self, *args, **kwargs)
        if self.groups > self.fast_mode:
            self.weight = nn.Parameter(
                self.weight.reshape(
                    self.groups, self.out_channels // self.groups, self.in_channels // self.groups, 1
                ).permute(3, 0, 2, 1)
            )
            self.bias = nn.Parameter(
                self.bias.unsqueeze(0).unsqueeze(-1)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.groups > self.fast_mode:
            x = x.reshape(-1, self.groups, self.in_channels // self.groups, 1)
            return (x * self.weight).sum(2, keepdims=True).permute(0, 1, 3, 2).reshape(-1, self.out_channels, 1) + self.bias
        else:
            return self._conv_forward(x, self.weight, self.bias)


class ConditionGeneration(nn.Module):
    def __init__(
        self,
        category_column_count: List[int],
        n_cond: int = 128,
        categorical_optimized: bool = False,
        fast_mode: int = 64,
    ):
        super(ConditionGeneration, self).__init__()
        self.fast_mode = fast_mode
        self.categorical_optimized = categorical_optimized

        index_info = self.extract_feature_metadata(category_column_count)
        self.numerical_index = index_info['numerical_index']
        self.categorical_index = index_info['categorical_index']
        self.categorical_count = index_info['categorical_count']

        categorical_offset = torch.tensor([0] + np.cumsum(self.categorical_count).tolist()[:-1]).long()
        self.register_buffer('categorical_offset', categorical_offset)

        self.n_cond = n_cond
        self.phi_1 = self.get_phi_1()

    def extract_feature_metadata(self, category_column_count: List[int]) -> Dict[str, List[int]]:
        numerical_index = [i for i, count in enumerate(category_column_count) if count == -1]
        categorical_index = [i for i, count in enumerate(category_column_count) if count != -1]
        categorical_count = [count for count in category_column_count if count != -1]
        return {
            'numerical_index': numerical_index,
            'categorical_index': categorical_index,
            'categorical_count': categorical_count,
        }

    def get_phi_1(self) -> nn.ModuleDict:
        phi_1 = nn.ModuleDict()
        if len(self.numerical_index):
            phi_1['num'] = nn.Sequential(
                # input = (b, n_num_col)
                # output = (b, n_num_col, n_cond)
                Reshape(-1, len(self.numerical_index), 1),
                FastGroupConv1d(
                    len(self.numerical_index),
                    len(self.numerical_index) * self.n_cond,
                    kernel_size=1,
                    groups=len(self.numerical_index),
                    fast_mode=self.fast_mode
                ), # (b, n_num_col, 1) -> (b, n_num_col * n_cond, 1)
                nn.Sigmoid(),
                Reshape(-1, len(self.numerical_index), self.n_cond)
            )
        if len(self.categorical_index):
            phi_1['cat'] = nn.ModuleDict()
            phi_1['cat']['embedder'] = nn.Embedding(sum(self.categorical_count), self.n_cond)            
            phi_1['cat']['mapper'] = nn.Sequential(
                # input = (b, n_cat_col, n_cond)
                # output = (b, n_cat_col, n_cond)
                Reshape(-1, len(self.categorical_index) * self.n_cond, 1),
                nn.GroupNorm(len(self.categorical_index), len(self.categorical_index) * self.n_cond),
                FastGroupConv1d(len(self.categorical_index) * self.n_cond, len(self.categorical_index)*self.n_cond, kernel_size=1, groups=len(self.categorical_index)*self.n_cond if self.categorical_optimized else len(self.categorical_index), fast_mode=self.fast_mode),                
                nn.Sigmoid(),
                Reshape(-1, len(self.categorical_index), self.n_cond)
            )
        return phi_1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        M = []

        if len(self.numerical_index):
            num_x = x[:, self.numerical_index].float()
            num_sample_emb = self.phi_1['num'](num_x)
            M.append(num_sample_emb)

        if len(self.categorical_index):
            cat_x = x[:, self.categorical_index].long() + self.cat_offset
            cat_sample_emb = self.phi_1['cat']['mapper'](self.phi_1['cat']['embedder'](cat_x))
            M.append(cat_sample_emb)

        M = torch.cat(M, dim=1) # (b, n_col, n_cond)
        M = M.permute(0, 2, 1) # (b, n_cond, n_col)
        return M


class rODTConstruction(nn.Module):
    def __init__(self, n_cond: int, n_col: int, d: int) -> None:
        super().__init__()
        self.permutator = torch.rand(n_cond * n_col).argsort(-1)
        self.d = d

    def forward(self, M: torch.Tensor) -> torch.Tensor:
        b, _, _, embed_dim = M.shape
        return M.reshape(b, -1, embed_dim)[:, self.permutator, :].reshape(b, -1, self.d, embed_dim)


class rODTForestConstruction(nn.Module):
    def __init__(
        self,
        n_col: int,
        n_rodt: int,
        n_cond: int,
        n_estimator: int,
        n_head: int = 1,
        n_hidden: int = 128,
        n_forest: int = 100,
        dropout: float = 0.0,
        fast_mode: int = 64,
        device = torch.device('cuda')
    ) -> None:

        super().__init__()
        self.device = device
        self.n_estimator = n_estimator
        self.n_forest = n_forest
        self.n_rodt = n_rodt
        self.n_head = n_head
        self.n_hidden = n_hidden

        self.sample_without_replacement_eval = self.get_sample_without_replacement()

    def get_sample_without_replacement(self) -> torch.Tensor:
        return torch.rand(self.n_forest, self.n_rodt, device=self.device).argsort(-1)[:, :self.n_estimator]

    def forward(self, w, E) -> torch.Tensor:
        # w: (b, n_rodt, 1)
        # E: (b, n_rodt, n_hidden)
    
        sample_without_replacement = self.get_sample_without_replacement() if self.training else self.sample_without_replacement_eval

        w_prime = w[:, sample_without_replacement].softmax(-2) # (b, n_forest, n_rodt, 1)
        E_prime = E[:, sample_without_replacement].reshape(
            E.shape[0], self.n_forest, self.n_estimator, self.n_hidden
        ) # (b, n_forest, n_rodt, n_hidden)

        F = (w_prime * E_prime).sum(-2).reshape(
            E.shape[0], self.n_forest, self.n_hidden
        ) # (b, n_forest, n_hidden)
        return F


class rODTForestBagging(nn.Module):
    def __init__(self, n_hidden: int, dropout: float, n_class: int) -> None:
        super().__init__()
        self.phi_3 = nn.Sequential(
            nn.LayerNorm(n_hidden),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.LayerNorm(n_hidden),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_class)
        )

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        return self.phi_3(F) # (b, n_forest, n_class)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.qw_proj = self.create_linear(embed_dim, embed_dim)
        self.kw_proj = self.create_linear(embed_dim, embed_dim)
        self.vw_proj = self.create_linear(embed_dim, embed_dim)
        self.ow_proj = self.create_linear(embed_dim, 1)

        self.qE_proj = self.create_linear(embed_dim, embed_dim)
        self.kE_proj = self.create_linear(embed_dim, embed_dim)
        self.vE_proj = self.create_linear(embed_dim, embed_dim)
        self.oE_proj = self.create_linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def create_linear(self, in_features: int, out_features: int):
        linear = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(linear.weight, gain=1 / 2 ** 0.5)
        return linear

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        query: Tensor of shape (batch_size, tgt_len, embed_dim)
        key: Tensor of shape (batch_size, src_len, embed_dim)
        value: Tensor of shape (batch_size, src_len, embed_dim)
        attn_mask: Optional[Tensor] of shape (tgt_len, src_len) or (batch_size, tgt_len, src_len)
        """
        batch_size, tgt_len, _ = query.size()
        src_len = key.size(1)

        # Compute w
        qw = self.qw_proj(query)  # shape: (batch_size, tgt_len, embed_dim)
        kw = self.kw_proj(key)    # shape: (batch_size, src_len, embed_dim)
        vw = self.vw_proj(value)  # shape: (batch_size, src_len, 1)

        # Reshape into multihead format
        qw = qw.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, tgt_len, head_dim)
        kw = kw.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, head_dim)
        vw = vw.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, 1)

        attn_w = F.softmax(torch.matmul(qw, kw.transpose(-2, -1)) / self.scaling, dim=-1) # shape: (batch_size, num_heads, tgt_len, src_len)
        attn_w = self.dropout(attn_w) 
        attn_w_output = torch.matmul(attn_w, vw) # shape: (batch_size, num_heads, tgt_len, head_dim)

        # Reshape back to original dimensions
        attn_w_output = attn_w_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) # shape: (batch_size, tgt_len, embed_dim)
        w_output = self.ow_proj(attn_w_output) # shape: (batch_size, tgt_len, 1)
        w_output = w_output.mean(1)

        # Compute w
        qE = self.qE_proj(query)  # shape: (batch_size, tgt_len, embed_dim)
        kE = self.kE_proj(key)    # shape: (batch_size, src_len, embed_dim)
        vE = self.vE_proj(value)  # shape: (batch_size, src_len, embed_dim)

        # Reshape into multihead format
        qE = qE.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, tgt_len, head_dim)
        kE = kE.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, head_dim)
        vE = vE.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, 1)

        attn_E = F.softmax(torch.matmul(qE, kE.transpose(-2, -1)) / self.scaling, dim=-1) # shape: (batch_size, num_heads, tgt_len, src_len)
        attn_E = self.dropout(attn_E) 
        attn_E_output = torch.matmul(attn_E, vE) # shape: (batch_size, num_heads, tgt_len, head_dim)

        # Reshape back to original dimensions
        attn_E_output = attn_E_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) # shape: (batch_size, tgt_len, embed_dim)
        E_output = self.oE_proj(attn_E_output) # shape: (batch_size, tgt_len, embed_dim)
        E_output = E_output.mean(1)

        return w_output, E_output


class DOFEN(nn.Module):
    def __init__(
        self,
        category_column_count: List[int],
        n_class: int, 
        m: int = 16, 
        d: int = 4, 
        n_head: int = 1,
        n_forest: int = 100,
        n_hidden: int = 128,
        dropout: float = 0.0, 
        categorical_optimized: bool = False,
        fast_mode: int = 2048,
        use_bagging_loss: bool = False,
        device=torch.device('cuda'),
    ):
        super().__init__()

        self.device = device
        self.n_class = 1 if n_class == -1 else n_class
        self.is_rgr = True if n_class == -1 else False

        self.m = m
        self.d = d
        self.n_head = n_head
        self.n_forest = n_forest
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.use_bagging_loss = use_bagging_loss

        self.n_cond = self.d * self.m
        self.n_col = len(category_column_count)
        self.n_rodt = self.n_cond * self.n_col // self.d
        self.n_estimator = max(2, int(self.n_col ** 0.5)) * self.n_cond // self.d

        self.condition_generation = ConditionGeneration(            
            category_column_count, 
            n_cond=self.n_cond, 
            categorical_optimized=categorical_optimized, 
            fast_mode=fast_mode,
        )

        self.proj = nn.Linear(1, self.n_hidden)
        self.rodt_construction = rODTConstruction(
            self.n_cond,
            self.n_col,
            self.d,
        )

        self.norm = nn.LayerNorm(n_hidden)
        self.attn = MultiheadAttention(
            embed_dim=self.n_hidden,
            num_heads=n_head,
            dropout=0.2,
        )

        self.rodt_forest_construction = rODTForestConstruction(
            self.n_col, 
            self.n_rodt, 
            self.n_cond,
            self.n_estimator,
            n_head=self.n_head, 
            n_hidden=self.n_hidden,
            n_forest=self.n_forest,
            dropout=self.dropout,
            fast_mode=fast_mode,
            device=self.device
        )
        self.rodt_forest_bagging = rODTForestBagging(
            self.n_hidden,
            self.dropout,
            self.n_class
        )

        self.ffn_dropout_rate = 0.1
        self.act = ReGLU()

    def compute_loss(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def evaluate(self, X: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        M = self.condition_generation(x)           # (b, n_cond, n_col)
        M = M.unsqueeze(-1)                        # (b, n_cond, n_col, 1)
        M = self.proj(M)                           # (b, n_cond, n_col, n_hidden)
        O = self.rodt_construction(M)              # (b, n_rodt, d, n_hidden)
        O = O.reshape(-1, self.d, self.n_hidden)   # (b * n_rodt, d, n_hidden)
        O = self.norm(O)
        w, E = self.attn(O, O, O)

        w = w.reshape(-1, self.n_rodt, 1)
        E = E.reshape(-1, self.n_rodt, self.n_hidden)

        E_norm = E / E.norm(dim=-1, keepdim=True)
        #sim_loss = torch.einsum('bij,bjk->bik', E_norm, E_norm.permute(0, 2, 1))

        F = self.rodt_forest_construction(w, E) # (b, n_forest, n_hidden)
        y_hats = self.rodt_forest_bagging(F)    # (b, n_rodt, n_class)
        y_hat = y_hats.mean(1)                  # (b, n_class)

        if y is not None:
            loss = self.compute_loss(
                y_hats.permute(0, 2, 1) if not self.is_rgr else y_hats, 
                y.unsqueeze(-1).expand(-1, self.n_forest)
            ) #+  sim_loss.abs().mean()
            if self.n_forest > 1 and self.training and self.use_bagging_loss:
                loss += self.compute_loss(y_hat, y)
            return {'pred': y_hat, 'loss': loss}
        return {'pred': y_hat}

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)['pred']


class DOFENClassifier(BaseClassifier, DOFEN):
    def __init__(
        self,
        category_column_count: List[int],
        n_class: int, 
        m: int = 16,
        d: int = 4,
        n_head: int = 4,
        n_forest: int = 100,
        n_hidden: int = 16,
        dropout: float = 0.0, 
        categorical_optimized: bool = False,
        fast_mode: int = 2048,
        use_bagging_loss: bool = False,
        device=torch.device('cuda'),
    ) -> None:

        super().__init__(
            category_column_count=category_column_count, 
            n_class=n_class, 
            m=m, 
            d=d, 
            n_head=n_head,
            n_forest=n_forest,
            n_hidden=n_hidden,
            dropout=dropout, 
            categorical_optimized=categorical_optimized,
            fast_mode=fast_mode,
            use_bagging_loss=use_bagging_loss,
            device=device,
        )

In [86]:
args = Config()
set_random_seed()

data_dict = load_pkl_data(Path(args.data_dir, args.data_id, '0.pkl'))
n_class = data_dict['label_cat_count']
task = 'r' if n_class == -1 else 'c'

target_transform = args.target_transform if task == 'r' else False
train_X = data_dict['x_train']
train_y = data_dict['y_train' if not target_transform else 'y_train_transform']
valid_X = data_dict['x_val']
valid_y = data_dict['y_val' if not target_transform else 'y_val_transform']
test_X = data_dict['x_test']
test_y = data_dict['y_test' if not target_transform else 'y_test_transform']

data_args = {
    'n_feature': train_X.shape[1],
    'n_train': train_X.shape[0],
    'n_valid': valid_X.shape[0],
    'n_test': test_X.shape[0],
}

model_params = {
    'category_column_count': data_dict['col_cat_count'],
}
if n_class != -1:
    model_params['n_class'] = n_class


model = DOFENClassifier(**model_params)
eval_funs = cls_eval_funs
metrics = 'accuracy'

wandb = None
trainer = Trainer(
    model,
    batch_size=args.batch_size,
    n_epoch=args.n_epoch,
    eval_funs=eval_funs,
    metric=metrics,
    logger=wandb,
)
trainer.fit(
    train_X=train_X, train_y=train_y,
    valid_X=valid_X, valid_y=valid_y,
    test_X=test_X, test_y=test_y,
)

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 1/250 [00:01<05:34,  1.34s/it]

{'epoch': 1, 'valid_accuracy': 0.5237649262467806, 'best_valid_epoch': 1, 'best_valid_score': 0.5237649262467806, 'test_accuracy': 0.5192153321292394, 'selected_test_score': 0.5192153321292394}


  1%|          | 2/250 [00:02<05:34,  1.35s/it]

{'epoch': 2, 'valid_accuracy': 0.520135799578553, 'best_valid_epoch': 1, 'best_valid_score': 0.5237649262467806, 'test_accuracy': 0.5195665261890428, 'selected_test_score': 0.5192153321292394}


  1%|          | 3/250 [00:04<05:32,  1.35s/it]

{'epoch': 3, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 1, 'best_valid_score': 0.5237649262467806, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.5192153321292394}


  2%|▏         | 4/250 [00:05<05:29,  1.34s/it]

{'epoch': 4, 'valid_accuracy': 0.49906345118239287, 'best_valid_epoch': 1, 'best_valid_score': 0.5237649262467806, 'test_accuracy': 0.5009532410194661, 'selected_test_score': 0.5192153321292394}


  2%|▏         | 5/250 [00:06<05:27,  1.34s/it]

{'epoch': 5, 'valid_accuracy': 0.5751580426129712, 'best_valid_epoch': 5, 'best_valid_score': 0.5751580426129712, 'test_accuracy': 0.5689845474613686, 'selected_test_score': 0.5689845474613686}


  2%|▏         | 6/250 [00:08<05:27,  1.34s/it]

{'epoch': 6, 'valid_accuracy': 0.49906345118239287, 'best_valid_epoch': 5, 'best_valid_score': 0.5751580426129712, 'test_accuracy': 0.5009532410194661, 'selected_test_score': 0.5689845474613686}


  3%|▎         | 7/250 [00:09<05:35,  1.38s/it]

{'epoch': 7, 'valid_accuracy': 0.5451884804495434, 'best_valid_epoch': 5, 'best_valid_score': 0.5751580426129712, 'test_accuracy': 0.5423439694962874, 'selected_test_score': 0.5689845474613686}


  3%|▎         | 8/250 [00:10<05:30,  1.37s/it]

{'epoch': 8, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 5, 'best_valid_score': 0.5751580426129712, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.5689845474613686}


  4%|▎         | 9/250 [00:12<05:26,  1.35s/it]

{'epoch': 9, 'valid_accuracy': 0.501053617419808, 'best_valid_epoch': 5, 'best_valid_score': 0.5751580426129712, 'test_accuracy': 0.5030604053782862, 'selected_test_score': 0.5689845474613686}


  4%|▍         | 10/250 [00:13<05:24,  1.35s/it]

{'epoch': 10, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 5, 'best_valid_score': 0.5751580426129712, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.5689845474613686}


  4%|▍         | 11/250 [00:14<05:22,  1.35s/it]

{'epoch': 11, 'valid_accuracy': 0.643291969093889, 'best_valid_epoch': 11, 'best_valid_score': 0.643291969093889, 'test_accuracy': 0.6426851294400964, 'selected_test_score': 0.6426851294400964}


  5%|▍         | 12/250 [00:16<05:20,  1.35s/it]

{'epoch': 12, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 11, 'best_valid_score': 0.643291969093889, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.6426851294400964}


  5%|▌         | 13/250 [00:17<05:19,  1.35s/it]

{'epoch': 13, 'valid_accuracy': 0.6472723015687193, 'best_valid_epoch': 13, 'best_valid_score': 0.6472723015687193, 'test_accuracy': 0.6520670278948425, 'selected_test_score': 0.6520670278948425}


  6%|▌         | 14/250 [00:18<05:12,  1.33s/it]

{'epoch': 14, 'valid_accuracy': 0.5024584406462187, 'best_valid_epoch': 13, 'best_valid_score': 0.6472723015687193, 'test_accuracy': 0.5011539233393538, 'selected_test_score': 0.6520670278948425}


  6%|▌         | 15/250 [00:20<05:12,  1.33s/it]

{'epoch': 15, 'valid_accuracy': 0.5063217045188481, 'best_valid_epoch': 13, 'best_valid_score': 0.6472723015687193, 'test_accuracy': 0.5078266104756171, 'selected_test_score': 0.6520670278948425}


  6%|▋         | 16/250 [00:21<05:11,  1.33s/it]

{'epoch': 16, 'valid_accuracy': 0.6663544837274643, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.665914107967088, 'selected_test_score': 0.665914107967088}


  7%|▋         | 17/250 [00:22<05:09,  1.33s/it]

{'epoch': 17, 'valid_accuracy': 0.5585343011004449, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.5537326911499096, 'selected_test_score': 0.665914107967088}


  7%|▋         | 18/250 [00:24<05:09,  1.33s/it]

{'epoch': 18, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.665914107967088}


  8%|▊         | 19/250 [00:25<05:13,  1.36s/it]

{'epoch': 19, 'valid_accuracy': 0.4998829313977991, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.5017559702990166, 'selected_test_score': 0.665914107967088}


  8%|▊         | 20/250 [00:26<05:10,  1.35s/it]

{'epoch': 20, 'valid_accuracy': 0.557363615078436, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.555639173188842, 'selected_test_score': 0.665914107967088}


  8%|▊         | 21/250 [00:28<05:08,  1.35s/it]

{'epoch': 21, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.665914107967088}


  9%|▉         | 22/250 [00:29<05:05,  1.34s/it]

{'epoch': 22, 'valid_accuracy': 0.5009365488176071, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.4990467589805338, 'selected_test_score': 0.665914107967088}


  9%|▉         | 23/250 [00:30<05:03,  1.34s/it]

{'epoch': 23, 'valid_accuracy': 0.5924841957387029, 'best_valid_epoch': 16, 'best_valid_score': 0.6663544837274643, 'test_accuracy': 0.5859923740718442, 'selected_test_score': 0.665914107967088}


 10%|▉         | 24/250 [00:32<05:03,  1.34s/it]

{'epoch': 24, 'valid_accuracy': 0.6741980800749239, 'best_valid_epoch': 24, 'best_valid_score': 0.6741980800749239, 'test_accuracy': 0.6708809953843067, 'selected_test_score': 0.6708809953843067}


 10%|█         | 25/250 [00:33<05:02,  1.34s/it]

{'epoch': 25, 'valid_accuracy': 0.6023179583235776, 'best_valid_epoch': 24, 'best_valid_score': 0.6741980800749239, 'test_accuracy': 0.5957254665863937, 'selected_test_score': 0.6708809953843067}


 10%|█         | 26/250 [00:34<05:00,  1.34s/it]

{'epoch': 26, 'valid_accuracy': 0.6322875204870054, 'best_valid_epoch': 24, 'best_valid_score': 0.6741980800749239, 'test_accuracy': 0.6303933373469798, 'selected_test_score': 0.6708809953843067}


 11%|█         | 27/250 [00:36<04:59,  1.34s/it]

{'epoch': 27, 'valid_accuracy': 0.6164832591898852, 'best_valid_epoch': 24, 'best_valid_score': 0.6741980800749239, 'test_accuracy': 0.6111278346377684, 'selected_test_score': 0.6708809953843067}


 11%|█         | 28/250 [00:37<04:57,  1.34s/it]

{'epoch': 28, 'valid_accuracy': 0.7013579957855304, 'best_valid_epoch': 28, 'best_valid_score': 0.7013579957855304, 'test_accuracy': 0.6978727674091912, 'selected_test_score': 0.6978727674091912}


 12%|█▏        | 29/250 [00:38<04:55,  1.34s/it]

{'epoch': 29, 'valid_accuracy': 0.6959728400842894, 'best_valid_epoch': 28, 'best_valid_score': 0.7013579957855304, 'test_accuracy': 0.6945113385510736, 'selected_test_score': 0.6978727674091912}


 12%|█▏        | 30/250 [00:40<04:59,  1.36s/it]

{'epoch': 30, 'valid_accuracy': 0.6511355654413487, 'best_valid_epoch': 28, 'best_valid_score': 0.7013579957855304, 'test_accuracy': 0.6495083283162754, 'selected_test_score': 0.6978727674091912}


 12%|█▏        | 31/250 [00:41<04:52,  1.33s/it]

{'epoch': 31, 'valid_accuracy': 0.696324045890892, 'best_valid_epoch': 28, 'best_valid_score': 0.7013579957855304, 'test_accuracy': 0.6951635560907085, 'selected_test_score': 0.6978727674091912}


 13%|█▎        | 32/250 [00:43<04:52,  1.34s/it]

{'epoch': 32, 'valid_accuracy': 0.6690470615780848, 'best_valid_epoch': 28, 'best_valid_score': 0.7013579957855304, 'test_accuracy': 0.6666666666666666, 'selected_test_score': 0.6978727674091912}


 13%|█▎        | 33/250 [00:44<04:51,  1.34s/it]

{'epoch': 33, 'valid_accuracy': 0.7244205104191056, 'best_valid_epoch': 33, 'best_valid_score': 0.7244205104191056, 'test_accuracy': 0.7215031105759583, 'selected_test_score': 0.7215031105759583}


 14%|█▎        | 34/250 [00:45<04:49,  1.34s/it]

{'epoch': 34, 'valid_accuracy': 0.7384687426832124, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7354003612281758, 'selected_test_score': 0.7354003612281758}


 14%|█▍        | 35/250 [00:47<04:47,  1.34s/it]

{'epoch': 35, 'valid_accuracy': 0.6891828611566377, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.6871362632952037, 'selected_test_score': 0.7354003612281758}


 14%|█▍        | 36/250 [00:48<04:46,  1.34s/it]

{'epoch': 36, 'valid_accuracy': 0.7142355420276282, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.713726670680313, 'selected_test_score': 0.7354003612281758}


 15%|█▍        | 37/250 [00:49<04:45,  1.34s/it]

{'epoch': 37, 'valid_accuracy': 0.7055724654647624, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7019365843869155, 'selected_test_score': 0.7354003612281758}


 15%|█▌        | 38/250 [00:51<04:45,  1.34s/it]

{'epoch': 38, 'valid_accuracy': 0.7045188480449543, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.6988761790086293, 'selected_test_score': 0.7354003612281758}


 16%|█▌        | 39/250 [00:52<04:44,  1.35s/it]

{'epoch': 39, 'valid_accuracy': 0.7343713416061812, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7335942203491872, 'selected_test_score': 0.7354003612281758}


 16%|█▌        | 40/250 [00:53<04:42,  1.34s/it]

{'epoch': 40, 'valid_accuracy': 0.7349566846171857, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7340959261489063, 'selected_test_score': 0.7354003612281758}


 16%|█▋        | 41/250 [00:55<04:39,  1.34s/it]

{'epoch': 41, 'valid_accuracy': 0.7291032545071412, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7299819385912101, 'selected_test_score': 0.7354003612281758}


 17%|█▋        | 42/250 [00:56<04:42,  1.36s/it]

{'epoch': 42, 'valid_accuracy': 0.7383516740810114, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7368051374673891, 'selected_test_score': 0.7354003612281758}


 17%|█▋        | 43/250 [00:57<04:40,  1.36s/it]

{'epoch': 43, 'valid_accuracy': 0.715406228049637, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7158840056191049, 'selected_test_score': 0.7354003612281758}


 18%|█▊        | 44/250 [00:59<04:38,  1.35s/it]

{'epoch': 44, 'valid_accuracy': 0.7183329431046593, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7185430463576159, 'selected_test_score': 0.7354003612281758}


 18%|█▊        | 45/250 [01:00<04:36,  1.35s/it]

{'epoch': 45, 'valid_accuracy': 0.7308592835401545, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7317880794701986, 'selected_test_score': 0.7354003612281758}


 18%|█▊        | 46/250 [01:01<04:33,  1.34s/it]

{'epoch': 46, 'valid_accuracy': 0.733551861390775, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7346478025285972, 'selected_test_score': 0.7354003612281758}


 19%|█▉        | 47/250 [01:03<04:31,  1.34s/it]

{'epoch': 47, 'valid_accuracy': 0.7348396160149848, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7360525787678106, 'selected_test_score': 0.7354003612281758}


 19%|█▉        | 48/250 [01:04<04:26,  1.32s/it]

{'epoch': 48, 'valid_accuracy': 0.7015921329899321, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.6978225968292193, 'selected_test_score': 0.7354003612281758}


 20%|█▉        | 49/250 [01:05<04:26,  1.33s/it]

{'epoch': 49, 'valid_accuracy': 0.7371809880590026, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7370559903672487, 'selected_test_score': 0.7354003612281758}


 20%|██        | 50/250 [01:07<04:26,  1.33s/it]

{'epoch': 50, 'valid_accuracy': 0.7151720908452353, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7161348585189645, 'selected_test_score': 0.7354003612281758}


 20%|██        | 51/250 [01:08<04:26,  1.34s/it]

{'epoch': 51, 'valid_accuracy': 0.7382346054788106, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.737407184427052, 'selected_test_score': 0.7354003612281758}


 21%|██        | 52/250 [01:09<04:25,  1.34s/it]

{'epoch': 52, 'valid_accuracy': 0.734605478810583, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.7371563315271924, 'selected_test_score': 0.7354003612281758}


 21%|██        | 53/250 [01:11<04:24,  1.34s/it]

{'epoch': 53, 'valid_accuracy': 0.7347225474127839, 'best_valid_epoch': 34, 'best_valid_score': 0.7384687426832124, 'test_accuracy': 0.736855308047361, 'selected_test_score': 0.7354003612281758}


 22%|██▏       | 54/250 [01:12<04:10,  1.28s/it]

{'epoch': 54, 'valid_accuracy': 0.7401077031140249, 'best_valid_epoch': 54, 'best_valid_score': 0.7401077031140249, 'test_accuracy': 0.7381095725466587, 'selected_test_score': 0.7381095725466587}


 22%|██▏       | 55/250 [01:13<03:37,  1.11s/it]

{'epoch': 55, 'valid_accuracy': 0.7319129009599625, 'best_valid_epoch': 54, 'best_valid_score': 0.7401077031140249, 'test_accuracy': 0.732139273530002, 'selected_test_score': 0.7381095725466587}


 22%|██▏       | 56/250 [01:13<03:15,  1.01s/it]

{'epoch': 56, 'valid_accuracy': 0.7432685553734488, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7411699779249448, 'selected_test_score': 0.7411699779249448}


 23%|██▎       | 57/250 [01:14<03:00,  1.07it/s]

{'epoch': 57, 'valid_accuracy': 0.7230156871926949, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7250150511739916, 'selected_test_score': 0.7411699779249448}


 23%|██▎       | 58/250 [01:15<02:48,  1.14it/s]

{'epoch': 58, 'valid_accuracy': 0.7402247717162257, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7397150311057595, 'selected_test_score': 0.7411699779249448}


 24%|██▎       | 59/250 [01:16<02:38,  1.20it/s]

{'epoch': 59, 'valid_accuracy': 0.7410442519316319, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7401163957455348, 'selected_test_score': 0.7411699779249448}


 24%|██▍       | 60/250 [01:16<02:37,  1.20it/s]

{'epoch': 60, 'valid_accuracy': 0.7410442519316319, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7400662251655629, 'selected_test_score': 0.7411699779249448}


 24%|██▍       | 61/250 [01:18<03:05,  1.02it/s]

{'epoch': 61, 'valid_accuracy': 0.7309763521423555, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7325908087497491, 'selected_test_score': 0.7411699779249448}


 25%|██▍       | 62/250 [01:19<03:24,  1.09s/it]

{'epoch': 62, 'valid_accuracy': 0.7409271833294311, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7411699779249448, 'selected_test_score': 0.7411699779249448}


 25%|██▌       | 63/250 [01:20<03:37,  1.16s/it]

{'epoch': 63, 'valid_accuracy': 0.7415125263404355, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.740567930965282, 'selected_test_score': 0.7411699779249448}


 26%|██▌       | 64/250 [01:22<03:46,  1.22s/it]

{'epoch': 64, 'valid_accuracy': 0.7369468508546008, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7381597431266306, 'selected_test_score': 0.7411699779249448}


 26%|██▌       | 65/250 [01:23<03:51,  1.25s/it]

{'epoch': 65, 'valid_accuracy': 0.7374151252634044, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7377583784868553, 'selected_test_score': 0.7411699779249448}


 26%|██▋       | 66/250 [01:24<03:59,  1.30s/it]

{'epoch': 66, 'valid_accuracy': 0.7425661437602435, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7411699779249448, 'selected_test_score': 0.7411699779249448}


 27%|██▋       | 67/250 [01:26<04:00,  1.32s/it]

{'epoch': 67, 'valid_accuracy': 0.7417466635448373, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7412703190848886, 'selected_test_score': 0.7411699779249448}


 27%|██▋       | 68/250 [01:27<04:01,  1.33s/it]

{'epoch': 68, 'valid_accuracy': 0.7410442519316319, 'best_valid_epoch': 56, 'best_valid_score': 0.7432685553734488, 'test_accuracy': 0.7425245835841863, 'selected_test_score': 0.7411699779249448}


 28%|██▊       | 69/250 [01:29<04:00,  1.33s/it]

{'epoch': 69, 'valid_accuracy': 0.7436197611800515, 'best_valid_epoch': 69, 'best_valid_score': 0.7436197611800515, 'test_accuracy': 0.7425747541641582, 'selected_test_score': 0.7425747541641582}


 28%|██▊       | 70/250 [01:30<03:59,  1.33s/it]

{'epoch': 70, 'valid_accuracy': 0.7426832123624444, 'best_valid_epoch': 69, 'best_valid_score': 0.7436197611800515, 'test_accuracy': 0.7422737306843267, 'selected_test_score': 0.7425747541641582}


 28%|██▊       | 71/250 [01:31<03:58,  1.33s/it]

{'epoch': 71, 'valid_accuracy': 0.743034418169047, 'best_valid_epoch': 69, 'best_valid_score': 0.7436197611800515, 'test_accuracy': 0.7425245835841863, 'selected_test_score': 0.7425747541641582}


 29%|██▉       | 72/250 [01:33<03:58,  1.34s/it]

{'epoch': 72, 'valid_accuracy': 0.738819948489815, 'best_valid_epoch': 69, 'best_valid_score': 0.7436197611800515, 'test_accuracy': 0.7404675898053381, 'selected_test_score': 0.7425747541641582}


 29%|██▉       | 73/250 [01:34<03:57,  1.34s/it]

{'epoch': 73, 'valid_accuracy': 0.7436197611800515, 'best_valid_epoch': 69, 'best_valid_score': 0.7436197611800515, 'test_accuracy': 0.7430262893839052, 'selected_test_score': 0.7425747541641582}


 30%|██▉       | 74/250 [01:35<03:54,  1.34s/it]

{'epoch': 74, 'valid_accuracy': 0.7438538983844533, 'best_valid_epoch': 74, 'best_valid_score': 0.7438538983844533, 'test_accuracy': 0.7431266305438491, 'selected_test_score': 0.7431266305438491}


 30%|███       | 75/250 [01:37<03:54,  1.34s/it]

{'epoch': 75, 'valid_accuracy': 0.7468976820416764, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.745283965482641, 'selected_test_score': 0.745283965482641}


 30%|███       | 76/250 [01:38<03:53,  1.34s/it]

{'epoch': 76, 'valid_accuracy': 0.7424490751580426, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7437286775035119, 'selected_test_score': 0.745283965482641}


 31%|███       | 77/250 [01:39<03:51,  1.34s/it]

{'epoch': 77, 'valid_accuracy': 0.7459611332240693, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7452337949026691, 'selected_test_score': 0.745283965482641}


 31%|███       | 78/250 [01:41<03:55,  1.37s/it]

{'epoch': 78, 'valid_accuracy': 0.7429173495668462, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7437788480834838, 'selected_test_score': 0.745283965482641}


 32%|███▏      | 79/250 [01:42<03:52,  1.36s/it]

{'epoch': 79, 'valid_accuracy': 0.746312339030672, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7446317479430062, 'selected_test_score': 0.745283965482641}


 32%|███▏      | 80/250 [01:43<03:51,  1.36s/it]

{'epoch': 80, 'valid_accuracy': 0.743034418169047, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7439795304033715, 'selected_test_score': 0.745283965482641}


 32%|███▏      | 81/250 [01:45<03:48,  1.35s/it]

{'epoch': 81, 'valid_accuracy': 0.7454928588152657, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7455348183825005, 'selected_test_score': 0.745283965482641}


 33%|███▎      | 82/250 [01:46<03:46,  1.35s/it]

{'epoch': 82, 'valid_accuracy': 0.7435026925778506, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7446317479430062, 'selected_test_score': 0.745283965482641}


 33%|███▎      | 83/250 [01:47<03:44,  1.35s/it]

{'epoch': 83, 'valid_accuracy': 0.7467806134394755, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7457355007023881, 'selected_test_score': 0.745283965482641}


 34%|███▎      | 84/250 [01:49<03:43,  1.35s/it]

{'epoch': 84, 'valid_accuracy': 0.7460782018262702, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.745835841862332, 'selected_test_score': 0.745283965482641}


 34%|███▍      | 85/250 [01:50<03:42,  1.35s/it]

{'epoch': 85, 'valid_accuracy': 0.7459611332240693, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7462372065021071, 'selected_test_score': 0.745283965482641}


 34%|███▍      | 86/250 [01:51<03:40,  1.34s/it]

{'epoch': 86, 'valid_accuracy': 0.7467806134394755, 'best_valid_epoch': 75, 'best_valid_score': 0.7468976820416764, 'test_accuracy': 0.7469395946217138, 'selected_test_score': 0.745283965482641}


 35%|███▍      | 87/250 [01:53<03:38,  1.34s/it]

{'epoch': 87, 'valid_accuracy': 0.748419573870288, 'best_valid_epoch': 87, 'best_valid_score': 0.748419573870288, 'test_accuracy': 0.7470399357816576, 'selected_test_score': 0.7470399357816576}


 35%|███▌      | 88/250 [01:54<03:36,  1.34s/it]

{'epoch': 88, 'valid_accuracy': 0.7478342308592836, 'best_valid_epoch': 87, 'best_valid_score': 0.748419573870288, 'test_accuracy': 0.7471402769416015, 'selected_test_score': 0.7470399357816576}


 36%|███▌      | 89/250 [01:55<03:35,  1.34s/it]

{'epoch': 89, 'valid_accuracy': 0.7464294076328729, 'best_valid_epoch': 87, 'best_valid_score': 0.748419573870288, 'test_accuracy': 0.7476921533212924, 'selected_test_score': 0.7470399357816576}


 36%|███▌      | 90/250 [01:57<03:39,  1.37s/it]

{'epoch': 90, 'valid_accuracy': 0.7477171622570826, 'best_valid_epoch': 87, 'best_valid_score': 0.748419573870288, 'test_accuracy': 0.7475918121613486, 'selected_test_score': 0.7470399357816576}


 36%|███▋      | 91/250 [01:58<03:34,  1.35s/it]

{'epoch': 91, 'valid_accuracy': 0.748419573870288, 'best_valid_epoch': 87, 'best_valid_score': 0.748419573870288, 'test_accuracy': 0.7482942002809553, 'selected_test_score': 0.7470399357816576}


 37%|███▋      | 92/250 [02:00<03:32,  1.35s/it]

{'epoch': 92, 'valid_accuracy': 0.7500585343011005, 'best_valid_epoch': 92, 'best_valid_score': 0.7500585343011005, 'test_accuracy': 0.7488460766606462, 'selected_test_score': 0.7488460766606462}


 37%|███▋      | 93/250 [02:01<03:30,  1.34s/it]

{'epoch': 93, 'valid_accuracy': 0.7490049168812924, 'best_valid_epoch': 92, 'best_valid_score': 0.7500585343011005, 'test_accuracy': 0.7472907886815171, 'selected_test_score': 0.7488460766606462}


 38%|███▊      | 94/250 [02:02<03:28,  1.34s/it]

{'epoch': 94, 'valid_accuracy': 0.7487707796768907, 'best_valid_epoch': 92, 'best_valid_score': 0.7500585343011005, 'test_accuracy': 0.7483443708609272, 'selected_test_score': 0.7488460766606462}


 38%|███▊      | 95/250 [02:04<03:27,  1.34s/it]

{'epoch': 95, 'valid_accuracy': 0.7486537110746898, 'best_valid_epoch': 92, 'best_valid_score': 0.7500585343011005, 'test_accuracy': 0.7486955649207305, 'selected_test_score': 0.7488460766606462}


 38%|███▊      | 96/250 [02:05<03:26,  1.34s/it]

{'epoch': 96, 'valid_accuracy': 0.7488878482790915, 'best_valid_epoch': 92, 'best_valid_score': 0.7500585343011005, 'test_accuracy': 0.7487959060806743, 'selected_test_score': 0.7488460766606462}


 39%|███▉      | 97/250 [02:06<03:25,  1.34s/it]

{'epoch': 97, 'valid_accuracy': 0.7492390540856942, 'best_valid_epoch': 92, 'best_valid_score': 0.7500585343011005, 'test_accuracy': 0.7502508528998595, 'selected_test_score': 0.7488460766606462}


 39%|███▉      | 98/250 [02:08<03:24,  1.34s/it]

{'epoch': 98, 'valid_accuracy': 0.7509950831187076, 'best_valid_epoch': 98, 'best_valid_score': 0.7509950831187076, 'test_accuracy': 0.7508027292795505, 'selected_test_score': 0.7508027292795505}


 40%|███▉      | 99/250 [02:09<03:22,  1.34s/it]

{'epoch': 99, 'valid_accuracy': 0.7508780145165067, 'best_valid_epoch': 98, 'best_valid_score': 0.7509950831187076, 'test_accuracy': 0.7511037527593819, 'selected_test_score': 0.7508027292795505}


 40%|████      | 100/250 [02:10<03:20,  1.34s/it]

{'epoch': 100, 'valid_accuracy': 0.7523999063451182, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7518061408789886, 'selected_test_score': 0.7518061408789886}


 40%|████      | 101/250 [02:12<03:19,  1.34s/it]

{'epoch': 101, 'valid_accuracy': 0.7495902598922969, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7513044350792695, 'selected_test_score': 0.7518061408789886}


 41%|████      | 102/250 [02:13<03:23,  1.37s/it]

{'epoch': 102, 'valid_accuracy': 0.7498243970966987, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7492976118803933, 'selected_test_score': 0.7518061408789886}


 41%|████      | 103/250 [02:14<03:20,  1.37s/it]

{'epoch': 103, 'valid_accuracy': 0.7521657691407164, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7525085289985952, 'selected_test_score': 0.7518061408789886}


 42%|████▏     | 104/250 [02:16<03:18,  1.36s/it]

{'epoch': 104, 'valid_accuracy': 0.7522828377429174, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7514549468191852, 'selected_test_score': 0.7518061408789886}


 42%|████▏     | 105/250 [02:17<03:15,  1.35s/it]

{'epoch': 105, 'valid_accuracy': 0.7507609459143058, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7498996588400562, 'selected_test_score': 0.7518061408789886}


 42%|████▏     | 106/250 [02:18<03:13,  1.35s/it]

{'epoch': 106, 'valid_accuracy': 0.751580426129712, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7526088701585391, 'selected_test_score': 0.7518061408789886}


 43%|████▎     | 107/250 [02:20<03:12,  1.35s/it]

{'epoch': 107, 'valid_accuracy': 0.7520487005385156, 'best_valid_epoch': 100, 'best_valid_score': 0.7523999063451182, 'test_accuracy': 0.7522075055187638, 'selected_test_score': 0.7518061408789886}


 43%|████▎     | 108/250 [02:21<03:09,  1.33s/it]

{'epoch': 108, 'valid_accuracy': 0.7528681807539218, 'best_valid_epoch': 108, 'best_valid_score': 0.7528681807539218, 'test_accuracy': 0.7533112582781457, 'selected_test_score': 0.7533112582781457}


 44%|████▎     | 109/250 [02:22<03:08,  1.34s/it]

{'epoch': 109, 'valid_accuracy': 0.7519316319363146, 'best_valid_epoch': 108, 'best_valid_score': 0.7528681807539218, 'test_accuracy': 0.753110575958258, 'selected_test_score': 0.7533112582781457}


 44%|████▍     | 110/250 [02:24<03:08,  1.34s/it]

{'epoch': 110, 'valid_accuracy': 0.7571997190353548, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7581276339554486, 'selected_test_score': 0.7581276339554486}


 44%|████▍     | 111/250 [02:25<03:06,  1.34s/it]

{'epoch': 111, 'valid_accuracy': 0.7539217981737298, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7544651816174995, 'selected_test_score': 0.7581276339554486}


 45%|████▍     | 112/250 [02:26<03:04,  1.34s/it]

{'epoch': 112, 'valid_accuracy': 0.7542730039803325, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7541139875576962, 'selected_test_score': 0.7581276339554486}


 45%|████▌     | 113/250 [02:28<03:02,  1.34s/it]

{'epoch': 113, 'valid_accuracy': 0.7522828377429174, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7512542644992977, 'selected_test_score': 0.7581276339554486}


 46%|████▌     | 114/250 [02:29<03:06,  1.37s/it]

{'epoch': 114, 'valid_accuracy': 0.7546242097869351, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7545153521974715, 'selected_test_score': 0.7581276339554486}


 46%|████▌     | 115/250 [02:31<03:03,  1.36s/it]

{'epoch': 115, 'valid_accuracy': 0.7504097401077031, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7477423239012643, 'selected_test_score': 0.7581276339554486}


 46%|████▋     | 116/250 [02:32<03:01,  1.36s/it]

{'epoch': 116, 'valid_accuracy': 0.756848513228752, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7609371864338752, 'selected_test_score': 0.7581276339554486}


 47%|████▋     | 117/250 [02:33<02:59,  1.35s/it]

{'epoch': 117, 'valid_accuracy': 0.7552095527979396, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7564720048163757, 'selected_test_score': 0.7581276339554486}


 47%|████▋     | 118/250 [02:35<02:57,  1.34s/it]

{'epoch': 118, 'valid_accuracy': 0.7562631702177476, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7562713224964881, 'selected_test_score': 0.7581276339554486}


 48%|████▊     | 119/250 [02:36<02:56,  1.35s/it]

{'epoch': 119, 'valid_accuracy': 0.756965581830953, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7560706401766004, 'selected_test_score': 0.7581276339554486}


 48%|████▊     | 120/250 [02:37<02:55,  1.35s/it]

{'epoch': 120, 'valid_accuracy': 0.7538047295715289, 'best_valid_epoch': 110, 'best_valid_score': 0.7571997190353548, 'test_accuracy': 0.7534617700180614, 'selected_test_score': 0.7581276339554486}


 48%|████▊     | 121/250 [02:39<02:53,  1.35s/it]

{'epoch': 121, 'valid_accuracy': 0.7574338562397565, 'best_valid_epoch': 121, 'best_valid_score': 0.7574338562397565, 'test_accuracy': 0.7572245635159542, 'selected_test_score': 0.7572245635159542}


 49%|████▉     | 122/250 [02:40<02:52,  1.34s/it]

{'epoch': 122, 'valid_accuracy': 0.7574338562397565, 'best_valid_epoch': 121, 'best_valid_score': 0.7574338562397565, 'test_accuracy': 0.7581778045354204, 'selected_test_score': 0.7572245635159542}


 49%|████▉     | 123/250 [02:41<02:50,  1.34s/it]

{'epoch': 123, 'valid_accuracy': 0.755794895808944, 'best_valid_epoch': 121, 'best_valid_score': 0.7574338562397565, 'test_accuracy': 0.7561208107565723, 'selected_test_score': 0.7572245635159542}


 50%|████▉     | 124/250 [02:43<02:48,  1.34s/it]

{'epoch': 124, 'valid_accuracy': 0.7531023179583236, 'best_valid_epoch': 121, 'best_valid_score': 0.7574338562397565, 'test_accuracy': 0.7529098936383705, 'selected_test_score': 0.7572245635159542}


 50%|█████     | 125/250 [02:44<02:45,  1.33s/it]

{'epoch': 125, 'valid_accuracy': 0.7586045422617654, 'best_valid_epoch': 125, 'best_valid_score': 0.7586045422617654, 'test_accuracy': 0.7593317278747742, 'selected_test_score': 0.7593317278747742}


 50%|█████     | 126/250 [02:45<02:48,  1.36s/it]

{'epoch': 126, 'valid_accuracy': 0.7589557480683681, 'best_valid_epoch': 126, 'best_valid_score': 0.7589557480683681, 'test_accuracy': 0.7607866746939594, 'selected_test_score': 0.7607866746939594}


 51%|█████     | 127/250 [02:47<02:46,  1.35s/it]

{'epoch': 127, 'valid_accuracy': 0.7610629829079841, 'best_valid_epoch': 127, 'best_valid_score': 0.7610629829079841, 'test_accuracy': 0.7652016857314871, 'selected_test_score': 0.7652016857314871}


 51%|█████     | 128/250 [02:48<02:44,  1.35s/it]

{'epoch': 128, 'valid_accuracy': 0.7594240224771717, 'best_valid_epoch': 127, 'best_valid_score': 0.7610629829079841, 'test_accuracy': 0.7610375275938189, 'selected_test_score': 0.7652016857314871}


 52%|█████▏    | 129/250 [02:49<02:42,  1.35s/it]

{'epoch': 129, 'valid_accuracy': 0.7561461016155467, 'best_valid_epoch': 127, 'best_valid_score': 0.7610629829079841, 'test_accuracy': 0.7561208107565723, 'selected_test_score': 0.7652016857314871}


 52%|█████▏    | 130/250 [02:51<02:40,  1.34s/it]

{'epoch': 130, 'valid_accuracy': 0.760126434090377, 'best_valid_epoch': 127, 'best_valid_score': 0.7610629829079841, 'test_accuracy': 0.7608870158539033, 'selected_test_score': 0.7652016857314871}


 52%|█████▏    | 131/250 [02:52<02:39,  1.34s/it]

{'epoch': 131, 'valid_accuracy': 0.7584874736595645, 'best_valid_epoch': 127, 'best_valid_score': 0.7610629829079841, 'test_accuracy': 0.7581276339554486, 'selected_test_score': 0.7652016857314871}


 53%|█████▎    | 132/250 [02:53<02:38,  1.35s/it]

{'epoch': 132, 'valid_accuracy': 0.7616483259189886, 'best_valid_epoch': 132, 'best_valid_score': 0.7616483259189886, 'test_accuracy': 0.7615894039735099, 'selected_test_score': 0.7615894039735099}


 53%|█████▎    | 133/250 [02:55<02:37,  1.34s/it]

{'epoch': 133, 'valid_accuracy': 0.7618824631233903, 'best_valid_epoch': 133, 'best_valid_score': 0.7618824631233903, 'test_accuracy': 0.7618402568733694, 'selected_test_score': 0.7618402568733694}


 54%|█████▎    | 134/250 [02:56<02:35,  1.34s/it]

{'epoch': 134, 'valid_accuracy': 0.7577850620463592, 'best_valid_epoch': 133, 'best_valid_score': 0.7618824631233903, 'test_accuracy': 0.7577764398956452, 'selected_test_score': 0.7618402568733694}


 54%|█████▍    | 135/250 [02:57<02:34,  1.34s/it]

{'epoch': 135, 'valid_accuracy': 0.7624678061343948, 'best_valid_epoch': 135, 'best_valid_score': 0.7624678061343948, 'test_accuracy': 0.762592815572948, 'selected_test_score': 0.762592815572948}


 54%|█████▍    | 136/250 [02:59<02:32,  1.34s/it]

{'epoch': 136, 'valid_accuracy': 0.7618824631233903, 'best_valid_epoch': 135, 'best_valid_score': 0.7624678061343948, 'test_accuracy': 0.7627934978928357, 'selected_test_score': 0.762592815572948}


 55%|█████▍    | 137/250 [03:00<02:31,  1.34s/it]

{'epoch': 137, 'valid_accuracy': 0.7576679934441582, 'best_valid_epoch': 135, 'best_valid_score': 0.7624678061343948, 'test_accuracy': 0.7578266104756171, 'selected_test_score': 0.762592815572948}


 55%|█████▌    | 138/250 [03:01<02:32,  1.36s/it]

{'epoch': 138, 'valid_accuracy': 0.7635214235542027, 'best_valid_epoch': 138, 'best_valid_score': 0.7635214235542027, 'test_accuracy': 0.764699979931768, 'selected_test_score': 0.764699979931768}


 56%|█████▌    | 139/250 [03:03<02:30,  1.35s/it]

{'epoch': 139, 'valid_accuracy': 0.753687660969328, 'best_valid_epoch': 138, 'best_valid_score': 0.7635214235542027, 'test_accuracy': 0.7533112582781457, 'selected_test_score': 0.764699979931768}


 56%|█████▌    | 140/250 [03:04<02:28,  1.35s/it]

{'epoch': 140, 'valid_accuracy': 0.762116600327792, 'best_valid_epoch': 138, 'best_valid_score': 0.7635214235542027, 'test_accuracy': 0.7623921332530604, 'selected_test_score': 0.764699979931768}


 56%|█████▋    | 141/250 [03:05<02:26,  1.35s/it]

{'epoch': 141, 'valid_accuracy': 0.7649262467806135, 'best_valid_epoch': 141, 'best_valid_score': 0.7649262467806135, 'test_accuracy': 0.7650010034115995, 'selected_test_score': 0.7650010034115995}


 57%|█████▋    | 142/250 [03:07<02:23,  1.33s/it]

{'epoch': 142, 'valid_accuracy': 0.7545071411847343, 'best_valid_epoch': 141, 'best_valid_score': 0.7649262467806135, 'test_accuracy': 0.754164158137668, 'selected_test_score': 0.7650010034115995}


 57%|█████▋    | 143/250 [03:08<02:22,  1.33s/it]

{'epoch': 143, 'valid_accuracy': 0.7637555607586045, 'best_valid_epoch': 141, 'best_valid_score': 0.7649262467806135, 'test_accuracy': 0.7645996387718242, 'selected_test_score': 0.7650010034115995}


 58%|█████▊    | 144/250 [03:09<02:21,  1.33s/it]

{'epoch': 144, 'valid_accuracy': 0.7657457269960196, 'best_valid_epoch': 144, 'best_valid_score': 0.7657457269960196, 'test_accuracy': 0.7682620911097733, 'selected_test_score': 0.7682620911097733}


 58%|█████▊    | 145/250 [03:11<02:20,  1.34s/it]

{'epoch': 145, 'valid_accuracy': 0.7658627955982206, 'best_valid_epoch': 145, 'best_valid_score': 0.7658627955982206, 'test_accuracy': 0.7653521974714028, 'selected_test_score': 0.7653521974714028}


 58%|█████▊    | 146/250 [03:12<02:19,  1.34s/it]

{'epoch': 146, 'valid_accuracy': 0.7596581596815734, 'best_valid_epoch': 145, 'best_valid_score': 0.7658627955982206, 'test_accuracy': 0.759432069034718, 'selected_test_score': 0.7653521974714028}


 59%|█████▉    | 147/250 [03:13<02:18,  1.34s/it]

{'epoch': 147, 'valid_accuracy': 0.7665652072114258, 'best_valid_epoch': 147, 'best_valid_score': 0.7665652072114258, 'test_accuracy': 0.7696166967690147, 'selected_test_score': 0.7696166967690147}


 59%|█████▉    | 148/250 [03:15<02:16,  1.34s/it]

{'epoch': 148, 'valid_accuracy': 0.7556778272067431, 'best_valid_epoch': 147, 'best_valid_score': 0.7665652072114258, 'test_accuracy': 0.7559201284366848, 'selected_test_score': 0.7696166967690147}


 60%|█████▉    | 149/250 [03:16<02:14,  1.34s/it]

{'epoch': 149, 'valid_accuracy': 0.7670334816202294, 'best_valid_epoch': 149, 'best_valid_score': 0.7670334816202294, 'test_accuracy': 0.7658539032711218, 'selected_test_score': 0.7658539032711218}


 60%|██████    | 150/250 [03:18<02:15,  1.36s/it]

{'epoch': 150, 'valid_accuracy': 0.7605947084991805, 'best_valid_epoch': 149, 'best_valid_score': 0.7670334816202294, 'test_accuracy': 0.7612883804936785, 'selected_test_score': 0.7658539032711218}


 60%|██████    | 151/250 [03:19<02:14,  1.36s/it]

{'epoch': 151, 'valid_accuracy': 0.7679700304378366, 'best_valid_epoch': 151, 'best_valid_score': 0.7679700304378366, 'test_accuracy': 0.768362432269717, 'selected_test_score': 0.768362432269717}


 61%|██████    | 152/250 [03:20<02:12,  1.35s/it]

{'epoch': 152, 'valid_accuracy': 0.7693748536642473, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7702689143086494, 'selected_test_score': 0.7702689143086494}


 61%|██████    | 153/250 [03:22<02:10,  1.35s/it]

{'epoch': 153, 'valid_accuracy': 0.7548583469913369, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7551173991571343, 'selected_test_score': 0.7702689143086494}


 62%|██████▏   | 154/250 [03:23<02:09,  1.34s/it]

{'epoch': 154, 'valid_accuracy': 0.7690236478576445, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7702187437286775, 'selected_test_score': 0.7702689143086494}


 62%|██████▏   | 155/250 [03:24<02:07,  1.34s/it]

{'epoch': 155, 'valid_accuracy': 0.7689065792554437, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7679108970499698, 'selected_test_score': 0.7702689143086494}


 62%|██████▏   | 156/250 [03:26<02:05,  1.34s/it]

{'epoch': 156, 'valid_accuracy': 0.7685553734488411, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7672085089303632, 'selected_test_score': 0.7702689143086494}


 63%|██████▎   | 157/250 [03:27<02:04,  1.34s/it]

{'epoch': 157, 'valid_accuracy': 0.7648091781784125, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7646498093517962, 'selected_test_score': 0.7702689143086494}


 63%|██████▎   | 158/250 [03:28<02:03,  1.34s/it]

{'epoch': 158, 'valid_accuracy': 0.7689065792554437, 'best_valid_epoch': 152, 'best_valid_score': 0.7693748536642473, 'test_accuracy': 0.7677102147300823, 'selected_test_score': 0.7702689143086494}


 64%|██████▎   | 159/250 [03:30<02:00,  1.32s/it]

{'epoch': 159, 'valid_accuracy': 0.7713650199016624, 'best_valid_epoch': 159, 'best_valid_score': 0.7713650199016624, 'test_accuracy': 0.7702689143086494, 'selected_test_score': 0.7702689143086494}


 64%|██████▍   | 160/250 [03:31<01:59,  1.33s/it]

{'epoch': 160, 'valid_accuracy': 0.7600093654881761, 'best_valid_epoch': 159, 'best_valid_score': 0.7713650199016624, 'test_accuracy': 0.7603351394742123, 'selected_test_score': 0.7702689143086494}


 64%|██████▍   | 161/250 [03:32<01:58,  1.33s/it]

{'epoch': 161, 'valid_accuracy': 0.770779676890658, 'best_valid_epoch': 159, 'best_valid_score': 0.7713650199016624, 'test_accuracy': 0.7705699377884808, 'selected_test_score': 0.7702689143086494}


 65%|██████▍   | 162/250 [03:34<01:58,  1.35s/it]

{'epoch': 162, 'valid_accuracy': 0.7732381175368767, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7717740317078066, 'selected_test_score': 0.7717740317078066}


 65%|██████▌   | 163/250 [03:35<01:58,  1.36s/it]

{'epoch': 163, 'valid_accuracy': 0.7639896979630063, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7642986152919927, 'selected_test_score': 0.7717740317078066}


 66%|██████▌   | 164/250 [03:36<01:56,  1.36s/it]

{'epoch': 164, 'valid_accuracy': 0.7571997190353548, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7587798514950833, 'selected_test_score': 0.7717740317078066}


 66%|██████▌   | 165/250 [03:38<01:54,  1.35s/it]

{'epoch': 165, 'valid_accuracy': 0.7582533364551627, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.758930363234999, 'selected_test_score': 0.7717740317078066}


 66%|██████▋   | 166/250 [03:39<01:53,  1.35s/it]

{'epoch': 166, 'valid_accuracy': 0.7723015687192695, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7709713024282561, 'selected_test_score': 0.7717740317078066}


 67%|██████▋   | 167/250 [03:40<01:51,  1.35s/it]

{'epoch': 167, 'valid_accuracy': 0.7703114024818544, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7700180614087899, 'selected_test_score': 0.7717740317078066}


 67%|██████▋   | 168/250 [03:42<01:50,  1.34s/it]

{'epoch': 168, 'valid_accuracy': 0.7679700304378366, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7679610676299418, 'selected_test_score': 0.7717740317078066}


 68%|██████▊   | 169/250 [03:43<01:48,  1.34s/it]

{'epoch': 169, 'valid_accuracy': 0.7694919222664481, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7698173790889022, 'selected_test_score': 0.7717740317078066}


 68%|██████▊   | 170/250 [03:44<01:47,  1.34s/it]

{'epoch': 170, 'valid_accuracy': 0.7715991571060642, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7708709612683122, 'selected_test_score': 0.7717740317078066}


 68%|██████▊   | 171/250 [03:46<01:46,  1.35s/it]

{'epoch': 171, 'valid_accuracy': 0.7732381175368767, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7726269315673289, 'selected_test_score': 0.7717740317078066}


 69%|██████▉   | 172/250 [03:47<01:44,  1.35s/it]

{'epoch': 172, 'valid_accuracy': 0.7684383048466401, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7679108970499698, 'selected_test_score': 0.7717740317078066}


 69%|██████▉   | 173/250 [03:48<01:43,  1.34s/it]

{'epoch': 173, 'valid_accuracy': 0.765394521189417, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7648504916716837, 'selected_test_score': 0.7717740317078066}


 70%|██████▉   | 174/250 [03:50<01:43,  1.36s/it]

{'epoch': 174, 'valid_accuracy': 0.7703114024818544, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7695163556090708, 'selected_test_score': 0.7717740317078066}


 70%|███████   | 175/250 [03:51<01:41,  1.36s/it]

{'epoch': 175, 'valid_accuracy': 0.7662140014048232, 'best_valid_epoch': 162, 'best_valid_score': 0.7732381175368767, 'test_accuracy': 0.7672085089303632, 'selected_test_score': 0.7717740317078066}


 70%|███████   | 176/250 [03:52<01:38,  1.33s/it]

{'epoch': 176, 'valid_accuracy': 0.7755794895808944, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7754364840457556, 'selected_test_score': 0.7754364840457556}


 71%|███████   | 177/250 [03:54<01:37,  1.34s/it]

{'epoch': 177, 'valid_accuracy': 0.7710138140950598, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7698173790889022, 'selected_test_score': 0.7754364840457556}


 71%|███████   | 178/250 [03:55<01:36,  1.34s/it]

{'epoch': 178, 'valid_accuracy': 0.7724186373214704, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7716736905478627, 'selected_test_score': 0.7754364840457556}


 72%|███████▏  | 179/250 [03:56<01:35,  1.34s/it]

{'epoch': 179, 'valid_accuracy': 0.7656286583938188, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7651013445715432, 'selected_test_score': 0.7754364840457556}


 72%|███████▏  | 180/250 [03:58<01:33,  1.34s/it]

{'epoch': 180, 'valid_accuracy': 0.7730039803324749, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7712723259080875, 'selected_test_score': 0.7754364840457556}


 72%|███████▏  | 181/250 [03:59<01:32,  1.34s/it]

{'epoch': 181, 'valid_accuracy': 0.7637555607586045, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7627934978928357, 'selected_test_score': 0.7754364840457556}


 73%|███████▎  | 182/250 [04:01<01:30,  1.34s/it]

{'epoch': 182, 'valid_accuracy': 0.7728869117302739, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7713224964880594, 'selected_test_score': 0.7754364840457556}


 73%|███████▎  | 183/250 [04:02<01:29,  1.34s/it]

{'epoch': 183, 'valid_accuracy': 0.7733551861390775, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7727774433072446, 'selected_test_score': 0.7754364840457556}


 74%|███████▎  | 184/250 [04:03<01:28,  1.34s/it]

{'epoch': 184, 'valid_accuracy': 0.770662608288457, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.769967890828818, 'selected_test_score': 0.7754364840457556}


 74%|███████▍  | 185/250 [04:05<01:27,  1.34s/it]

{'epoch': 185, 'valid_accuracy': 0.7755794895808944, 'best_valid_epoch': 176, 'best_valid_score': 0.7755794895808944, 'test_accuracy': 0.7745334136062613, 'selected_test_score': 0.7754364840457556}


 74%|███████▍  | 186/250 [04:06<01:27,  1.37s/it]

{'epoch': 186, 'valid_accuracy': 0.7768672442051042, 'best_valid_epoch': 186, 'best_valid_score': 0.7768672442051042, 'test_accuracy': 0.7755869957856713, 'selected_test_score': 0.7755869957856713}


 75%|███████▍  | 187/250 [04:07<01:26,  1.37s/it]

{'epoch': 187, 'valid_accuracy': 0.775930695387497, 'best_valid_epoch': 186, 'best_valid_score': 0.7768672442051042, 'test_accuracy': 0.7750852899859523, 'selected_test_score': 0.7755869957856713}


 75%|███████▌  | 188/250 [04:09<01:24,  1.36s/it]

{'epoch': 188, 'valid_accuracy': 0.7708967454928588, 'best_valid_epoch': 186, 'best_valid_score': 0.7768672442051042, 'test_accuracy': 0.7704194260485652, 'selected_test_score': 0.7755869957856713}


 76%|███████▌  | 189/250 [04:10<01:22,  1.36s/it]

{'epoch': 189, 'valid_accuracy': 0.7667993444158276, 'best_valid_epoch': 186, 'best_valid_score': 0.7768672442051042, 'test_accuracy': 0.7655027092113185, 'selected_test_score': 0.7755869957856713}


 76%|███████▌  | 190/250 [04:11<01:21,  1.36s/it]

{'epoch': 190, 'valid_accuracy': 0.7713650199016624, 'best_valid_epoch': 186, 'best_valid_score': 0.7768672442051042, 'test_accuracy': 0.7697672085089303, 'selected_test_score': 0.7755869957856713}


 76%|███████▋  | 191/250 [04:13<01:19,  1.35s/it]

{'epoch': 191, 'valid_accuracy': 0.7767501756029033, 'best_valid_epoch': 186, 'best_valid_score': 0.7768672442051042, 'test_accuracy': 0.7776941601444912, 'selected_test_score': 0.7755869957856713}


 77%|███████▋  | 192/250 [04:14<01:17,  1.34s/it]

{'epoch': 192, 'valid_accuracy': 0.777101381409506, 'best_valid_epoch': 192, 'best_valid_score': 0.777101381409506, 'test_accuracy': 0.7768412602849689, 'selected_test_score': 0.7768412602849689}


 77%|███████▋  | 193/250 [04:15<01:15,  1.33s/it]

{'epoch': 193, 'valid_accuracy': 0.7786232732381175, 'best_valid_epoch': 193, 'best_valid_score': 0.7786232732381175, 'test_accuracy': 0.7774934778246037, 'selected_test_score': 0.7774934778246037}


 78%|███████▊  | 194/250 [04:17<01:14,  1.33s/it]

{'epoch': 194, 'valid_accuracy': 0.7787403418403184, 'best_valid_epoch': 194, 'best_valid_score': 0.7787403418403184, 'test_accuracy': 0.7776439895645194, 'selected_test_score': 0.7776439895645194}


 78%|███████▊  | 195/250 [04:18<01:13,  1.34s/it]

{'epoch': 195, 'valid_accuracy': 0.7785062046359167, 'best_valid_epoch': 194, 'best_valid_score': 0.7787403418403184, 'test_accuracy': 0.7779951836243227, 'selected_test_score': 0.7776439895645194}


 78%|███████▊  | 196/250 [04:19<01:12,  1.34s/it]

{'epoch': 196, 'valid_accuracy': 0.7754624209786936, 'best_valid_epoch': 194, 'best_valid_score': 0.7787403418403184, 'test_accuracy': 0.7765402368051375, 'selected_test_score': 0.7776439895645194}


 79%|███████▉  | 197/250 [04:21<01:11,  1.34s/it]

{'epoch': 197, 'valid_accuracy': 0.7787403418403184, 'best_valid_epoch': 194, 'best_valid_score': 0.7787403418403184, 'test_accuracy': 0.7774433072446317, 'selected_test_score': 0.7776439895645194}


 79%|███████▉  | 198/250 [04:22<01:10,  1.36s/it]

{'epoch': 198, 'valid_accuracy': 0.779208616249122, 'best_valid_epoch': 198, 'best_valid_score': 0.779208616249122, 'test_accuracy': 0.7782962071041541, 'selected_test_score': 0.7782962071041541}


 80%|███████▉  | 199/250 [04:23<01:04,  1.27s/it]

{'epoch': 199, 'valid_accuracy': 0.7794427534535238, 'best_valid_epoch': 199, 'best_valid_score': 0.7794427534535238, 'test_accuracy': 0.778898254063817, 'selected_test_score': 0.778898254063817}


 80%|████████  | 200/250 [04:24<00:55,  1.11s/it]

{'epoch': 200, 'valid_accuracy': 0.7747600093654882, 'best_valid_epoch': 199, 'best_valid_score': 0.7794427534535238, 'test_accuracy': 0.7718743728677504, 'selected_test_score': 0.778898254063817}


 80%|████████  | 201/250 [04:25<00:49,  1.00s/it]

{'epoch': 201, 'valid_accuracy': 0.7785062046359167, 'best_valid_epoch': 199, 'best_valid_score': 0.7794427534535238, 'test_accuracy': 0.7775938189845475, 'selected_test_score': 0.778898254063817}


 81%|████████  | 202/250 [04:25<00:44,  1.07it/s]

{'epoch': 202, 'valid_accuracy': 0.77499414656989, 'best_valid_epoch': 199, 'best_valid_score': 0.7794427534535238, 'test_accuracy': 0.7745835841862332, 'selected_test_score': 0.778898254063817}


 81%|████████  | 203/250 [04:26<00:41,  1.14it/s]

{'epoch': 203, 'valid_accuracy': 0.7753453523764926, 'best_valid_epoch': 199, 'best_valid_score': 0.7794427534535238, 'test_accuracy': 0.773128637367048, 'selected_test_score': 0.778898254063817}


 82%|████████▏ | 204/250 [04:27<00:38,  1.20it/s]

{'epoch': 204, 'valid_accuracy': 0.7756965581830952, 'best_valid_epoch': 199, 'best_valid_score': 0.7794427534535238, 'test_accuracy': 0.773078466787076, 'selected_test_score': 0.778898254063817}


 82%|████████▏ | 205/250 [04:28<00:36,  1.23it/s]

{'epoch': 205, 'valid_accuracy': 0.7800280964645282, 'best_valid_epoch': 205, 'best_valid_score': 0.7800280964645282, 'test_accuracy': 0.7795003010234798, 'selected_test_score': 0.7795003010234798}


 82%|████████▏ | 206/250 [04:29<00:42,  1.03it/s]

{'epoch': 206, 'valid_accuracy': 0.7730039803324749, 'best_valid_epoch': 205, 'best_valid_score': 0.7800280964645282, 'test_accuracy': 0.7699177202488461, 'selected_test_score': 0.7795003010234798}


 83%|████████▎ | 207/250 [04:30<00:47,  1.10s/it]

{'epoch': 207, 'valid_accuracy': 0.7761648325918988, 'best_valid_epoch': 205, 'best_valid_score': 0.7800280964645282, 'test_accuracy': 0.7752358017258679, 'selected_test_score': 0.7795003010234798}


 83%|████████▎ | 208/250 [04:32<00:49,  1.18s/it]

{'epoch': 208, 'valid_accuracy': 0.775930695387497, 'best_valid_epoch': 205, 'best_valid_score': 0.7800280964645282, 'test_accuracy': 0.7739815372265704, 'selected_test_score': 0.7795003010234798}


 84%|████████▎ | 209/250 [04:33<00:51,  1.24s/it]

{'epoch': 209, 'valid_accuracy': 0.7755794895808944, 'best_valid_epoch': 205, 'best_valid_score': 0.7800280964645282, 'test_accuracy': 0.7721252257676099, 'selected_test_score': 0.7795003010234798}


 84%|████████▍ | 210/250 [04:35<00:52,  1.30s/it]

{'epoch': 210, 'valid_accuracy': 0.7787403418403184, 'best_valid_epoch': 205, 'best_valid_score': 0.7800280964645282, 'test_accuracy': 0.778848083483845, 'selected_test_score': 0.7795003010234798}


 84%|████████▍ | 211/250 [04:36<00:51,  1.32s/it]

{'epoch': 211, 'valid_accuracy': 0.781315851088738, 'best_valid_epoch': 211, 'best_valid_score': 0.781315851088738, 'test_accuracy': 0.7800020068231989, 'selected_test_score': 0.7800020068231989}


 85%|████████▍ | 212/250 [04:37<00:51,  1.34s/it]

{'epoch': 212, 'valid_accuracy': 0.7820182627019433, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7805538832028899, 'selected_test_score': 0.7805538832028899}


 85%|████████▌ | 213/250 [04:39<00:50,  1.36s/it]

{'epoch': 213, 'valid_accuracy': 0.7658627955982206, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7665562913907285, 'selected_test_score': 0.7805538832028899}


 86%|████████▌ | 214/250 [04:40<00:49,  1.36s/it]

{'epoch': 214, 'valid_accuracy': 0.7775696558183095, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7753361428858118, 'selected_test_score': 0.7805538832028899}


 86%|████████▌ | 215/250 [04:42<00:47,  1.37s/it]

{'epoch': 215, 'valid_accuracy': 0.7685553734488411, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7691149909692956, 'selected_test_score': 0.7805538832028899}


 86%|████████▋ | 216/250 [04:43<00:46,  1.37s/it]

{'epoch': 216, 'valid_accuracy': 0.7789744790447202, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7797511539233394, 'selected_test_score': 0.7805538832028899}


 87%|████████▋ | 217/250 [04:44<00:45,  1.37s/it]

{'epoch': 217, 'valid_accuracy': 0.7786232732381175, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7770419426048565, 'selected_test_score': 0.7805538832028899}


 87%|████████▋ | 218/250 [04:46<00:43,  1.37s/it]

{'epoch': 218, 'valid_accuracy': 0.7804963708733318, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7798013245033113, 'selected_test_score': 0.7805538832028899}


 88%|████████▊ | 219/250 [04:47<00:42,  1.36s/it]

{'epoch': 219, 'valid_accuracy': 0.7776867244205105, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7728777844671885, 'selected_test_score': 0.7805538832028899}


 88%|████████▊ | 220/250 [04:48<00:40,  1.36s/it]

{'epoch': 220, 'valid_accuracy': 0.7774525872161087, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7743327312863737, 'selected_test_score': 0.7805538832028899}


 88%|████████▊ | 221/250 [04:50<00:39,  1.37s/it]

{'epoch': 221, 'valid_accuracy': 0.7774525872161087, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7734798314268513, 'selected_test_score': 0.7805538832028899}


 89%|████████▉ | 222/250 [04:51<00:39,  1.40s/it]

{'epoch': 222, 'valid_accuracy': 0.7778037930227113, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7758378486855309, 'selected_test_score': 0.7805538832028899}


 89%|████████▉ | 223/250 [04:53<00:37,  1.39s/it]

{'epoch': 223, 'valid_accuracy': 0.7819011940997425, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7811057595825808, 'selected_test_score': 0.7805538832028899}


 90%|████████▉ | 224/250 [04:54<00:35,  1.38s/it]

{'epoch': 224, 'valid_accuracy': 0.779208616249122, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.779951836243227, 'selected_test_score': 0.7805538832028899}


 90%|█████████ | 225/250 [04:55<00:34,  1.38s/it]

{'epoch': 225, 'valid_accuracy': 0.7800280964645282, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7800020068231989, 'selected_test_score': 0.7805538832028899}


 90%|█████████ | 226/250 [04:57<00:33,  1.38s/it]

{'epoch': 226, 'valid_accuracy': 0.7794427534535238, 'best_valid_epoch': 212, 'best_valid_score': 0.7820182627019433, 'test_accuracy': 0.7798514950832831, 'selected_test_score': 0.7805538832028899}


 91%|█████████ | 227/250 [04:58<00:31,  1.38s/it]

{'epoch': 227, 'valid_accuracy': 0.7830718801217513, 'best_valid_epoch': 227, 'best_valid_score': 0.7830718801217513, 'test_accuracy': 0.7822596829219346, 'selected_test_score': 0.7822596829219346}


 91%|█████████ | 228/250 [04:59<00:30,  1.38s/it]

{'epoch': 228, 'valid_accuracy': 0.7719503629126668, 'best_valid_epoch': 227, 'best_valid_score': 0.7830718801217513, 'test_accuracy': 0.7711218141681718, 'selected_test_score': 0.7822596829219346}


 92%|█████████▏| 229/250 [05:01<00:28,  1.38s/it]

{'epoch': 229, 'valid_accuracy': 0.783423085928354, 'best_valid_epoch': 229, 'best_valid_score': 0.783423085928354, 'test_accuracy': 0.7816576359622717, 'selected_test_score': 0.7816576359622717}


 92%|█████████▏| 230/250 [05:02<00:27,  1.38s/it]

{'epoch': 230, 'valid_accuracy': 0.7810817138843362, 'best_valid_epoch': 229, 'best_valid_score': 0.783423085928354, 'test_accuracy': 0.780453542042946, 'selected_test_score': 0.7816576359622717}


 92%|█████████▏| 231/250 [05:04<00:26,  1.38s/it]

{'epoch': 231, 'valid_accuracy': 0.7836572231327558, 'best_valid_epoch': 231, 'best_valid_score': 0.7836572231327558, 'test_accuracy': 0.7823098535019065, 'selected_test_score': 0.7823098535019065}


 93%|█████████▎| 232/250 [05:05<00:24,  1.37s/it]

{'epoch': 232, 'valid_accuracy': 0.7775696558183095, 'best_valid_epoch': 231, 'best_valid_score': 0.7836572231327558, 'test_accuracy': 0.7739313666465985, 'selected_test_score': 0.7823098535019065}


 93%|█████████▎| 233/250 [05:06<00:23,  1.37s/it]

{'epoch': 233, 'valid_accuracy': 0.7829548115195505, 'best_valid_epoch': 231, 'best_valid_score': 0.7836572231327558, 'test_accuracy': 0.7824603652418222, 'selected_test_score': 0.7823098535019065}


 94%|█████████▎| 234/250 [05:08<00:22,  1.40s/it]

{'epoch': 234, 'valid_accuracy': 0.7698431280730508, 'best_valid_epoch': 231, 'best_valid_score': 0.7836572231327558, 'test_accuracy': 0.7704194260485652, 'selected_test_score': 0.7823098535019065}


 94%|█████████▍| 235/250 [05:09<00:20,  1.39s/it]

{'epoch': 235, 'valid_accuracy': 0.7827206743151487, 'best_valid_epoch': 231, 'best_valid_score': 0.7836572231327558, 'test_accuracy': 0.7829620710415413, 'selected_test_score': 0.7823098535019065}


 94%|█████████▍| 236/250 [05:10<00:19,  1.37s/it]

{'epoch': 236, 'valid_accuracy': 0.7803793022711308, 'best_valid_epoch': 231, 'best_valid_score': 0.7836572231327558, 'test_accuracy': 0.7783463776841261, 'selected_test_score': 0.7823098535019065}


 95%|█████████▍| 237/250 [05:12<00:17,  1.37s/it]

{'epoch': 237, 'valid_accuracy': 0.7848279091547647, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7824101946618502, 'selected_test_score': 0.7824101946618502}


 95%|█████████▌| 238/250 [05:13<00:16,  1.38s/it]

{'epoch': 238, 'valid_accuracy': 0.7793256848513229, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7747842665061208, 'selected_test_score': 0.7824101946618502}


 96%|█████████▌| 239/250 [05:15<00:15,  1.37s/it]

{'epoch': 239, 'valid_accuracy': 0.7801451650667292, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7779951836243227, 'selected_test_score': 0.7824101946618502}


 96%|█████████▌| 240/250 [05:16<00:13,  1.37s/it]

{'epoch': 240, 'valid_accuracy': 0.779208616249122, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7791992775436484, 'selected_test_score': 0.7824101946618502}


 96%|█████████▋| 241/250 [05:17<00:12,  1.37s/it]

{'epoch': 241, 'valid_accuracy': 0.7807305080777336, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7801023479831427, 'selected_test_score': 0.7824101946618502}


 97%|█████████▋| 242/250 [05:19<00:11,  1.38s/it]

{'epoch': 242, 'valid_accuracy': 0.7744088035588855, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7718743728677504, 'selected_test_score': 0.7824101946618502}


 97%|█████████▋| 243/250 [05:20<00:09,  1.38s/it]

{'epoch': 243, 'valid_accuracy': 0.7801451650667292, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7790989363837046, 'selected_test_score': 0.7824101946618502}


 98%|█████████▊| 244/250 [05:22<00:08,  1.38s/it]

{'epoch': 244, 'valid_accuracy': 0.7815499882931398, 'best_valid_epoch': 237, 'best_valid_score': 0.7848279091547647, 'test_accuracy': 0.7816074653822999, 'selected_test_score': 0.7824101946618502}


 98%|█████████▊| 245/250 [05:23<00:06,  1.38s/it]

{'epoch': 245, 'valid_accuracy': 0.7849449777569656, 'best_valid_epoch': 245, 'best_valid_score': 0.7849449777569656, 'test_accuracy': 0.7835641180012041, 'selected_test_score': 0.7835641180012041}


 98%|█████████▊| 246/250 [05:24<00:05,  1.39s/it]

{'epoch': 246, 'valid_accuracy': 0.7816670568953407, 'best_valid_epoch': 245, 'best_valid_score': 0.7849449777569656, 'test_accuracy': 0.7813566124824403, 'selected_test_score': 0.7835641180012041}


 99%|█████████▉| 247/250 [05:26<00:04,  1.39s/it]

{'epoch': 247, 'valid_accuracy': 0.7848279091547647, 'best_valid_epoch': 245, 'best_valid_score': 0.7849449777569656, 'test_accuracy': 0.7838149709010637, 'selected_test_score': 0.7835641180012041}


 99%|█████████▉| 248/250 [05:27<00:02,  1.39s/it]

{'epoch': 248, 'valid_accuracy': 0.7793256848513229, 'best_valid_epoch': 245, 'best_valid_score': 0.7849449777569656, 'test_accuracy': 0.779951836243227, 'selected_test_score': 0.7835641180012041}


100%|█████████▉| 249/250 [05:28<00:01,  1.39s/it]

{'epoch': 249, 'valid_accuracy': 0.7800280964645282, 'best_valid_epoch': 245, 'best_valid_score': 0.7849449777569656, 'test_accuracy': 0.7798013245033113, 'selected_test_score': 0.7835641180012041}


100%|██████████| 250/250 [05:30<00:00,  1.32s/it]

{'epoch': 250, 'valid_accuracy': 0.7811987824865371, 'best_valid_epoch': 245, 'best_valid_score': 0.7849449777569656, 'test_accuracy': 0.7798514950832831, 'selected_test_score': 0.7835641180012041}





In [87]:
model

DOFENClassifier(
  (condition_generation): ConditionGeneration(
    (phi_1): ModuleDict(
      (num): Sequential(
        (0): Reshape()
        (1): FastGroupConv1d(7, 448, kernel_size=(1,), stride=(1,), groups=7)
        (2): Sigmoid()
        (3): Reshape()
      )
    )
  )
  (proj): Linear(in_features=1, out_features=16, bias=True)
  (rodt_construction): rODTConstruction()
  (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (attn): MultiheadAttention(
    (qw_proj): Linear(in_features=16, out_features=16, bias=True)
    (kw_proj): Linear(in_features=16, out_features=16, bias=True)
    (vw_proj): Linear(in_features=16, out_features=16, bias=True)
    (ow_proj): Linear(in_features=16, out_features=1, bias=True)
    (qE_proj): Linear(in_features=16, out_features=16, bias=True)
    (kE_proj): Linear(in_features=16, out_features=16, bias=True)
    (vE_proj): Linear(in_features=16, out_features=16, bias=True)
    (oE_proj): Linear(in_features=16, out_features=16, bias=True)

In [88]:
model.attn

MultiheadAttention(
  (qw_proj): Linear(in_features=16, out_features=16, bias=True)
  (kw_proj): Linear(in_features=16, out_features=16, bias=True)
  (vw_proj): Linear(in_features=16, out_features=16, bias=True)
  (ow_proj): Linear(in_features=16, out_features=1, bias=True)
  (qE_proj): Linear(in_features=16, out_features=16, bias=True)
  (kE_proj): Linear(in_features=16, out_features=16, bias=True)
  (vE_proj): Linear(in_features=16, out_features=16, bias=True)
  (oE_proj): Linear(in_features=16, out_features=16, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [89]:
model.cuda()(torch.tensor(train_X[:256]).cuda())

{'pred': tensor([[ 0.2512, -0.0253],
         [-0.4203,  0.9439],
         [ 0.0696,  0.2283],
         [ 0.4746, -0.3041],
         [ 0.8893, -0.8632],
         [ 0.7030, -0.6088],
         [ 0.8161, -0.7506],
         [ 1.2206, -1.3292],
         [ 0.3068, -0.0989],
         [ 1.1070, -1.1566],
         [ 0.4825, -0.3156],
         [-0.3567,  0.8564],
         [ 1.0180, -1.0287],
         [-0.9935,  1.7610],
         [ 0.0696,  0.2234],
         [ 0.2868, -0.0663],
         [ 0.5032, -0.3719],
         [ 0.3093, -0.1194],
         [ 0.2660, -0.0410],
         [-0.8812,  1.6197],
         [-0.9912,  1.7624],
         [ 1.0766, -1.1170],
         [ 0.6953, -0.6056],
         [-0.7315,  1.4160],
         [ 0.2244,  0.0132],
         [ 1.1348, -1.2020],
         [ 0.0075,  0.3226],
         [-0.6082,  1.2394],
         [-0.1344,  0.5107],
         [ 0.6661, -0.5536],
         [ 0.6013, -0.4703],
         [-0.0216,  0.3661],
         [ 1.0307, -1.0517],
         [ 1.4218, -1.6203],
      

In [90]:
activations = {}

def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

handle = model.get_submodule("attn.oE_proj").register_forward_hook(get_activation("oE_proj"))

In [91]:
model.cuda()(torch.tensor(train_X[:256]).cuda())

{'pred': tensor([[ 0.2512, -0.0253],
         [-0.4203,  0.9439],
         [ 0.0696,  0.2283],
         [ 0.4746, -0.3041],
         [ 0.8893, -0.8632],
         [ 0.7030, -0.6088],
         [ 0.8161, -0.7506],
         [ 1.2206, -1.3292],
         [ 0.3068, -0.0989],
         [ 1.1070, -1.1566],
         [ 0.4825, -0.3156],
         [-0.3567,  0.8564],
         [ 1.0180, -1.0287],
         [-0.9935,  1.7610],
         [ 0.0696,  0.2234],
         [ 0.2868, -0.0663],
         [ 0.5032, -0.3719],
         [ 0.3093, -0.1194],
         [ 0.2660, -0.0410],
         [-0.8812,  1.6197],
         [-0.9912,  1.7624],
         [ 1.0766, -1.1170],
         [ 0.6953, -0.6056],
         [-0.7315,  1.4160],
         [ 0.2244,  0.0132],
         [ 1.1348, -1.2020],
         [ 0.0075,  0.3226],
         [-0.6082,  1.2394],
         [-0.1344,  0.5107],
         [ 0.6661, -0.5536],
         [ 0.6013, -0.4703],
         [-0.0216,  0.3661],
         [ 1.0307, -1.0517],
         [ 1.4218, -1.6203],
      

In [92]:
oE_proj_output = activations["oE_proj"]
print(oE_proj_output.shape)

torch.Size([28672, 4, 16])


In [93]:
oE_proj_output = oE_proj_output.reshape(256, -1, 4, 16)

In [94]:
oE_proj_output.shape

torch.Size([256, 112, 4, 16])

In [95]:
x = oE_proj_output[0].view(112, -1)

In [96]:
x_ = x / ((x ** 2).sum(dim=1) ** 0.5)[:, None]
sim = x_ @ x_.T
sim

tensor([[1.0000, 0.9864, 0.9966,  ..., 0.8331, 0.9967, 0.9446],
        [0.9864, 1.0000, 0.9699,  ..., 0.7317, 0.9705, 0.8786],
        [0.9966, 0.9699, 1.0000,  ..., 0.8756, 1.0000, 0.9684],
        ...,
        [0.8331, 0.7317, 0.8756,  ..., 1.0000, 0.8742, 0.9683],
        [0.9967, 0.9705, 1.0000,  ..., 0.8742, 1.0000, 0.9677],
        [0.9446, 0.8786, 0.9684,  ..., 0.9683, 0.9677, 1.0000]],
       device='cuda:0')

In [97]:
sim[0]

tensor([1.0000, 0.9864, 0.9966, 0.9950, 0.9987, 0.9916, 0.8117, 0.9820, 0.8812,
        0.9980, 0.8736, 0.9975, 0.9953, 0.7756, 0.8760, 0.9834, 0.8134, 0.9957,
        0.9969, 0.9995, 0.9883, 0.9851, 0.9917, 0.9916, 0.9787, 0.8838, 0.9921,
        0.9646, 0.9666, 0.9698, 0.9916, 0.9912, 0.9725, 0.8803, 0.8372, 0.9318,
        0.9742, 0.8234, 0.9980, 0.9429, 0.9838, 0.9895, 0.8261, 0.9684, 0.9800,
        0.7999, 0.8236, 0.7952, 0.7878, 0.9967, 0.9602, 0.9996, 0.9560, 0.8429,
        0.9823, 0.9934, 0.9196, 0.9720, 0.8384, 0.9974, 0.9914, 0.9196, 0.9915,
        0.9987, 0.9260, 0.9399, 0.9613, 0.9955, 0.9972, 0.9868, 0.8621, 0.9961,
        0.9988, 0.9552, 0.9926, 0.9950, 0.9754, 0.8788, 0.9974, 0.8764, 0.9973,
        0.8756, 0.7560, 0.9939, 0.9600, 0.9306, 0.9128, 0.9945, 0.9986, 0.9579,
        0.8050, 0.8734, 0.8130, 0.9218, 0.9924, 0.9981, 0.9805, 0.8486, 0.9330,
        0.9673, 0.8925, 0.9559, 0.9933, 0.9302, 0.8561, 0.8310, 0.8679, 0.9961,
        0.9639, 0.8331, 0.9967, 0.9446],