In [24]:
from pathlib import Path

from src.metric import cls_eval_funs, reg_eval_funs
from src.trainer import Trainer
from src.utils import load_pkl_data, set_random_seed

In [25]:
class Config:
    data_dir = '/home/jiawei/Desktop/github/DOFEN/tabular-benchmark/tabular_benchmark_data'
    data_id = '361061'
    n_epoch = 300
    batch_size = 256
    target_transform = False

In [33]:
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from src.models.base import BaseClassifier, BaseRegressor
from src.models.layers import ResidualLayer, ReGLU


class Reshape(nn.Module):
    def __init__(self, *args: int) -> None:
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(self.shape)


class FastGroupConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        self.fast_mode = kwargs.pop('fast_mode')
        nn.Conv1d.__init__(self, *args, **kwargs)
        if self.groups > self.fast_mode:
            self.weight = nn.Parameter(
                self.weight.reshape(
                    self.groups, self.out_channels // self.groups, self.in_channels // self.groups, 1
                ).permute(3, 0, 2, 1)
            )
            self.bias = nn.Parameter(
                self.bias.unsqueeze(0).unsqueeze(-1)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.groups > self.fast_mode:
            x = x.reshape(-1, self.groups, self.in_channels // self.groups, 1)
            return (x * self.weight).sum(2, keepdims=True).permute(0, 1, 3, 2).reshape(-1, self.out_channels, 1) + self.bias
        else:
            return self._conv_forward(x, self.weight, self.bias)


class ConditionGeneration(nn.Module):
    def __init__(
        self,
        category_column_count: List[int],
        n_cond: int = 128,
        n_hidden: int = 4,
        categorical_optimized: bool = False,
        fast_mode: int = 64,
    ):
        super(ConditionGeneration, self).__init__()
        self.fast_mode = fast_mode
        self.categorical_optimized = categorical_optimized

        index_info = self.extract_feature_metadata(category_column_count)
        self.numerical_index = index_info['numerical_index']
        self.categorical_index = index_info['categorical_index']
        self.categorical_count = index_info['categorical_count']

        categorical_offset = torch.tensor([0] + np.cumsum(self.categorical_count).tolist()[:-1]).long()
        self.register_buffer('categorical_offset', categorical_offset)

        self.n_cond = n_cond
        self.n_hidden = n_hidden
        self.phi_1 = self.get_phi_1()

    def extract_feature_metadata(self, category_column_count: List[int]) -> Dict[str, List[int]]:
        numerical_index = [i for i, count in enumerate(category_column_count) if count == -1]
        categorical_index = [i for i, count in enumerate(category_column_count) if count != -1]
        categorical_count = [count for count in category_column_count if count != -1]
        return {
            'numerical_index': numerical_index,
            'categorical_index': categorical_index,
            'categorical_count': categorical_count,
        }

    def get_phi_1(self) -> nn.ModuleDict:
        phi_1 = nn.ModuleDict()
        if len(self.numerical_index):
            phi_1['num'] = nn.Sequential(
                # input = (b, n_num_col)
                # output = (b, n_num_col, n_cond)
                Reshape(-1, len(self.numerical_index), 1),
                FastGroupConv1d(
                    len(self.numerical_index),
                    len(self.numerical_index) * self.n_cond * self.n_hidden,
                    kernel_size=1,
                    groups=len(self.numerical_index),
                    fast_mode=self.fast_mode
                ), # (b, n_num_col, 1) -> (b, n_num_col * n_cond, 1)
                # nn.Sigmoid(),
                Reshape(-1, len(self.numerical_index), self.n_cond, self.n_hidden)
            )
        if len(self.categorical_index):
            phi_1['cat'] = nn.ModuleDict()
            phi_1['cat']['embedder'] = nn.Embedding(sum(self.categorical_count), self.n_cond * self.n_hidden)            
            phi_1['cat']['mapper'] = nn.Sequential(
                # input = (b, n_cat_col, n_cond)
                # output = (b, n_cat_col, n_cond)
                Reshape(-1, len(self.categorical_index) * self.n_cond  * self.n_hidden, 1),
                nn.GroupNorm(len(self.categorical_index), len(self.categorical_index) * self.n_cond  * self.n_hidden),
                FastGroupConv1d(
                    len(self.categorical_index) * self.n_cond  * self.n_hidden,
                    len(self.categorical_index) * self.n_cond  * self.n_hidden,
                    kernel_size=1,
                    groups=len(self.categorical_index) * self.n_cond  * self.n_hidden if self.categorical_optimized else len(self.categorical_index),
                    fast_mode=self.fast_mode),                
                nn.Sigmoid(),
                Reshape(-1, len(self.categorical_index), self.n_cond, self.n_hidden)
            )
        return phi_1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        M = []

        if len(self.numerical_index):
            num_x = x[:, self.numerical_index].float()
            num_sample_emb = self.phi_1['num'](num_x)
            M.append(num_sample_emb)

        if len(self.categorical_index):
            cat_x = x[:, self.categorical_index].long() + self.cat_offset
            cat_sample_emb = self.phi_1['cat']['mapper'](self.phi_1['cat']['embedder'](cat_x))
            M.append(cat_sample_emb)

        M = torch.cat(M, dim=1) # (b, n_col, n_cond, n_hidden)
        M = M.permute(0, 2, 1, 3) # (b, n_cond, n_col, n_hidden)
        return M


class rODTConstruction(nn.Module):
    def __init__(self, n_cond: int, n_col: int, d: int) -> None:
        super().__init__()
        self.permutator = torch.rand(n_cond * n_col).argsort(-1)
        self.d = d

    def forward(self, M: torch.Tensor) -> torch.Tensor:
        b, _, _, embed_dim = M.shape
        return M.reshape(b, -1, embed_dim)[:, self.permutator, :].reshape(b, -1, self.d, embed_dim)


class rODTForestConstruction(nn.Module):
    def __init__(
        self,
        n_col: int,
        n_rodt: int,
        n_cond: int,
        n_estimator: int,
        n_head: int = 1,
        n_hidden: int = 128,
        n_forest: int = 100,
        dropout: float = 0.0,
        fast_mode: int = 64,
        device = torch.device('cuda')
    ) -> None:

        super().__init__()
        self.device = device
        self.n_estimator = n_estimator
        self.n_forest = n_forest
        self.n_rodt = n_rodt
        self.n_head = n_head
        self.n_hidden = n_hidden

        self.sample_without_replacement_eval = self.get_sample_without_replacement()

    def get_sample_without_replacement(self) -> torch.Tensor:
        return torch.rand(self.n_forest, self.n_rodt, device=self.device).argsort(-1)[:, :self.n_estimator]

    def forward(self, w, E) -> torch.Tensor:
        # w: (b, n_rodt, 1)
        # E: (b, n_rodt, n_hidden)
    
        sample_without_replacement = self.get_sample_without_replacement() if self.training else self.sample_without_replacement_eval

        w_prime = w[:, sample_without_replacement].softmax(-2) # (b, n_forest, n_rodt, 1)
        E_prime = E[:, sample_without_replacement].reshape(
            E.shape[0], self.n_forest, self.n_estimator, self.n_hidden
        ) # (b, n_forest, n_rodt, n_hidden)

        F = (w_prime * E_prime).sum(-2).reshape(
            E.shape[0], self.n_forest, self.n_hidden
        ) # (b, n_forest, n_hidden)
        return F


class rODTForestBagging(nn.Module):
    def __init__(self, n_hidden: int, dropout: float, n_class: int) -> None:
        super().__init__()
        self.phi_3 = nn.Sequential(
            nn.LayerNorm(n_hidden),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.LayerNorm(n_hidden),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_class)
        )

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        return self.phi_3(F) # (b, n_forest, n_class)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.qw_proj = self.create_linear(embed_dim, embed_dim)
        self.kw_proj = self.create_linear(embed_dim, embed_dim)
        self.vw_proj = self.create_linear(embed_dim, embed_dim)
        self.ow_proj = self.create_linear(embed_dim, 1)

        self.qE_proj = self.create_linear(embed_dim, embed_dim)
        self.kE_proj = self.create_linear(embed_dim, embed_dim)
        self.vE_proj = self.create_linear(embed_dim, embed_dim)
        self.oE_proj = self.create_linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def create_linear(self, in_features: int, out_features: int):
        linear = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(linear.weight, gain=1 / 2 ** 0.5)
        return linear

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        query: Tensor of shape (batch_size, tgt_len, embed_dim)
        key: Tensor of shape (batch_size, src_len, embed_dim)
        value: Tensor of shape (batch_size, src_len, embed_dim)
        attn_mask: Optional[Tensor] of shape (tgt_len, src_len) or (batch_size, tgt_len, src_len)
        """
        batch_size, tgt_len, _ = query.size()
        src_len = key.size(1)

        # Compute w
        qw = self.qw_proj(query)  # shape: (batch_size, tgt_len, embed_dim)
        kw = self.kw_proj(key)    # shape: (batch_size, src_len, embed_dim)
        vw = self.vw_proj(value)  # shape: (batch_size, src_len, 1)

        # Reshape into multihead format
        qw = qw.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, tgt_len, head_dim)
        kw = kw.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, head_dim)
        vw = vw.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, 1)

        attn_w = F.softmax(torch.matmul(qw, kw.transpose(-2, -1)) / self.scaling, dim=-1) # shape: (batch_size, num_heads, tgt_len, src_len)
        attn_w = self.dropout(attn_w) 
        attn_w_output = torch.matmul(attn_w, vw) # shape: (batch_size, num_heads, tgt_len, head_dim)

        # Reshape back to original dimensions
        attn_w_output = attn_w_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) # shape: (batch_size, tgt_len, embed_dim)
        w_output = self.ow_proj(attn_w_output) # shape: (batch_size, tgt_len, 1)
        w_output = w_output.mean(1)

        # Compute w
        qE = self.qE_proj(query)  # shape: (batch_size, tgt_len, embed_dim)
        kE = self.kE_proj(key)    # shape: (batch_size, src_len, embed_dim)
        vE = self.vE_proj(value)  # shape: (batch_size, src_len, embed_dim)

        # Reshape into multihead format
        qE = qE.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, tgt_len, head_dim)
        kE = kE.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, head_dim)
        vE = vE.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # shape: (batch_size, num_heads, src_len, 1)

        attn_E = F.softmax(torch.matmul(qE, kE.transpose(-2, -1)) / self.scaling, dim=-1) # shape: (batch_size, num_heads, tgt_len, src_len)
        attn_E = self.dropout(attn_E) 
        attn_E_output = torch.matmul(attn_E, vE) # shape: (batch_size, num_heads, tgt_len, head_dim)

        # Reshape back to original dimensions
        attn_E_output = attn_E_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) # shape: (batch_size, tgt_len, embed_dim)
        E_output = self.oE_proj(attn_E_output) # shape: (batch_size, tgt_len, embed_dim)
        E_output = E_output.mean(1)

        return w_output, E_output


class DOFEN(nn.Module):
    def __init__(
        self,
        category_column_count: List[int],
        n_class: int, 
        m: int = 16, 
        d: int = 4, 
        n_head: int = 1,
        n_forest: int = 100,
        n_hidden: int = 128,
        dropout: float = 0.0, 
        categorical_optimized: bool = False,
        fast_mode: int = 2048,
        use_bagging_loss: bool = False,
        device=torch.device('cuda'),
    ):
        super().__init__()

        self.device = device
        self.n_class = 1 if n_class == -1 else n_class
        self.is_rgr = True if n_class == -1 else False

        self.m = m
        self.d = d
        self.n_head = n_head
        self.n_forest = n_forest
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.use_bagging_loss = use_bagging_loss

        self.n_cond = self.d * self.m
        self.n_col = len(category_column_count)
        self.n_rodt = self.n_cond * self.n_col // self.d
        self.n_estimator = max(2, int(self.n_col ** 0.5)) * self.n_cond // self.d

        self.condition_generation = ConditionGeneration(            
            category_column_count, 
            n_cond=self.n_cond,
            n_hidden=self.n_hidden,
            categorical_optimized=categorical_optimized, 
            fast_mode=fast_mode,
        )

        self.proj = nn.Linear(1, self.n_hidden)
        self.rodt_construction = rODTConstruction(
            self.n_cond,
            self.n_col,
            self.d,
        )

        self.norm = nn.LayerNorm(n_hidden)
        self.attn = MultiheadAttention(
            embed_dim=self.n_hidden,
            num_heads=n_head,
            dropout=0.2,
        )

        self.rodt_forest_construction = rODTForestConstruction(
            self.n_col, 
            self.n_rodt, 
            self.n_cond,
            self.n_estimator,
            n_head=self.n_head, 
            n_hidden=self.n_hidden,
            n_forest=self.n_forest,
            dropout=self.dropout,
            fast_mode=fast_mode,
            device=self.device
        )
        self.rodt_forest_bagging = rODTForestBagging(
            self.n_hidden,
            self.dropout,
            self.n_class
        )

        self.ffn_dropout_rate = 0.1
        self.act = ReGLU()

    def compute_loss(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def evaluate(self, X: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        M = self.condition_generation(x)           # (b, n_cond, n_col, n_hidden)
        # M = M.unsqueeze(-1)                        # (b, n_cond, n_col, 1)
        # M = self.proj(M)                           # (b, n_cond, n_col, n_hidden)
        O = self.rodt_construction(M)              # (b, n_rodt, d, n_hidden)
        O = O.reshape(-1, self.d, self.n_hidden)   # (b * n_rodt, d, n_hidden)
        O = self.norm(O)
        w, E = self.attn(O, O, O)

        w = w.reshape(-1, self.n_rodt, 1)
        E = E.reshape(-1, self.n_rodt, self.n_hidden)

        E_norm = E / E.norm(dim=-1, keepdim=True)
        sim_loss = torch.einsum('bij,bjk->bik', E_norm, E_norm.permute(0, 2, 1))

        F = self.rodt_forest_construction(w, E) # (b, n_forest, n_hidden)
        y_hats = self.rodt_forest_bagging(F)    # (b, n_rodt, n_class)
        y_hat = y_hats.mean(1)                  # (b, n_class)

        if y is not None:
            loss = self.compute_loss(
                y_hats.permute(0, 2, 1) if not self.is_rgr else y_hats, 
                y.unsqueeze(-1).expand(-1, self.n_forest)
            ) #+  sim_loss.abs().mean()
            if self.n_forest > 1 and self.training and self.use_bagging_loss:
                loss += self.compute_loss(y_hat, y)
            return {'pred': y_hat, 'loss': loss}
        return {'pred': y_hat}

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)['pred']


class DOFENClassifier(BaseClassifier, DOFEN):
    def __init__(
        self,
        category_column_count: List[int],
        n_class: int, 
        m: int = 16,
        d: int = 4,
        n_head: int = 4,
        n_forest: int = 100,
        n_hidden: int = 16,
        dropout: float = 0.0, 
        categorical_optimized: bool = False,
        fast_mode: int = 2048,
        use_bagging_loss: bool = False,
        device=torch.device('cuda'),
    ) -> None:

        super().__init__(
            category_column_count=category_column_count, 
            n_class=n_class, 
            m=m, 
            d=d, 
            n_head=n_head,
            n_forest=n_forest,
            n_hidden=n_hidden,
            dropout=dropout, 
            categorical_optimized=categorical_optimized,
            fast_mode=fast_mode,
            use_bagging_loss=use_bagging_loss,
            device=device,
        )

In [34]:
args = Config()
set_random_seed()

data_dict = load_pkl_data(Path(args.data_dir, args.data_id, '0.pkl'))
n_class = data_dict['label_cat_count']
task = 'r' if n_class == -1 else 'c'

target_transform = args.target_transform if task == 'r' else False
train_X = data_dict['x_train']
train_y = data_dict['y_train' if not target_transform else 'y_train_transform']
valid_X = data_dict['x_val']
valid_y = data_dict['y_val' if not target_transform else 'y_val_transform']
test_X = data_dict['x_test']
test_y = data_dict['y_test' if not target_transform else 'y_test_transform']

data_args = {
    'n_feature': train_X.shape[1],
    'n_train': train_X.shape[0],
    'n_valid': valid_X.shape[0],
    'n_test': test_X.shape[0],
}

model_params = {
    'category_column_count': data_dict['col_cat_count'],
}
if n_class != -1:
    model_params['n_class'] = n_class


model = DOFENClassifier(**model_params)
eval_funs = cls_eval_funs
metrics = 'accuracy'

trainer = Trainer(
    model,
    batch_size=args.batch_size,
    n_epoch=args.n_epoch,
    eval_funs=eval_funs,
    metric=metrics,
    verbose=True,
)
trainer.fit(
    train_X=train_X, train_y=train_y,
    valid_X=valid_X, valid_y=valid_y,
    test_X=test_X, test_y=test_y,
)

  0%|          | 1/300 [00:01<08:54,  1.79s/it, selected_test_score=0.583]

{'epoch': 1, 'valid_accuracy': 0.58528, 'best_valid_epoch': 1, 'best_valid_score': 0.58528, 'test_accuracy': 0.58318, 'selected_test_score': 0.58318}


  1%|          | 2/300 [00:03<08:51,  1.78s/it, selected_test_score=0.583]

{'epoch': 2, 'valid_accuracy': 0.58652, 'best_valid_epoch': 2, 'best_valid_score': 0.58652, 'test_accuracy': 0.58254, 'selected_test_score': 0.58254}


  1%|          | 3/300 [00:05<08:57,  1.81s/it, selected_test_score=0.623]

{'epoch': 3, 'valid_accuracy': 0.62584, 'best_valid_epoch': 3, 'best_valid_score': 0.62584, 'test_accuracy': 0.6226, 'selected_test_score': 0.6226}


  1%|▏         | 4/300 [00:07<08:52,  1.80s/it, selected_test_score=0.632]

{'epoch': 4, 'valid_accuracy': 0.6338, 'best_valid_epoch': 4, 'best_valid_score': 0.6338, 'test_accuracy': 0.63182, 'selected_test_score': 0.63182}


  2%|▏         | 5/300 [00:08<08:48,  1.79s/it, selected_test_score=0.634]

{'epoch': 5, 'valid_accuracy': 0.63818, 'best_valid_epoch': 5, 'best_valid_score': 0.63818, 'test_accuracy': 0.63388, 'selected_test_score': 0.63388}


  2%|▏         | 6/300 [00:10<08:45,  1.79s/it, selected_test_score=0.654]

{'epoch': 6, 'valid_accuracy': 0.65716, 'best_valid_epoch': 6, 'best_valid_score': 0.65716, 'test_accuracy': 0.65448, 'selected_test_score': 0.65448}


  2%|▏         | 7/300 [00:12<08:49,  1.81s/it, selected_test_score=0.665]

{'epoch': 7, 'valid_accuracy': 0.6677, 'best_valid_epoch': 7, 'best_valid_score': 0.6677, 'test_accuracy': 0.66542, 'selected_test_score': 0.66542}


  3%|▎         | 8/300 [00:14<08:44,  1.80s/it, selected_test_score=0.674]

{'epoch': 8, 'valid_accuracy': 0.67396, 'best_valid_epoch': 8, 'best_valid_score': 0.67396, 'test_accuracy': 0.67352, 'selected_test_score': 0.67352}


  3%|▎         | 9/300 [00:16<08:41,  1.79s/it, selected_test_score=0.69] 

{'epoch': 9, 'valid_accuracy': 0.69098, 'best_valid_epoch': 9, 'best_valid_score': 0.69098, 'test_accuracy': 0.69034, 'selected_test_score': 0.69034}


  3%|▎         | 10/300 [00:17<08:43,  1.81s/it, selected_test_score=0.698]

{'epoch': 10, 'valid_accuracy': 0.70032, 'best_valid_epoch': 10, 'best_valid_score': 0.70032, 'test_accuracy': 0.69798, 'selected_test_score': 0.69798}


  4%|▎         | 11/300 [00:19<08:39,  1.80s/it, selected_test_score=0.703]

{'epoch': 11, 'valid_accuracy': 0.70426, 'best_valid_epoch': 11, 'best_valid_score': 0.70426, 'test_accuracy': 0.70304, 'selected_test_score': 0.70304}


  4%|▍         | 12/300 [00:21<08:36,  1.79s/it, selected_test_score=0.711]

{'epoch': 12, 'valid_accuracy': 0.71224, 'best_valid_epoch': 12, 'best_valid_score': 0.71224, 'test_accuracy': 0.7109, 'selected_test_score': 0.7109}


  4%|▍         | 13/300 [00:23<08:33,  1.79s/it, selected_test_score=0.711]

{'epoch': 13, 'valid_accuracy': 0.71186, 'best_valid_epoch': 12, 'best_valid_score': 0.71224, 'test_accuracy': 0.70906, 'selected_test_score': 0.7109}


  5%|▍         | 14/300 [00:25<08:35,  1.80s/it, selected_test_score=0.715]

{'epoch': 14, 'valid_accuracy': 0.71782, 'best_valid_epoch': 14, 'best_valid_score': 0.71782, 'test_accuracy': 0.71456, 'selected_test_score': 0.71456}


  5%|▌         | 15/300 [00:26<08:32,  1.80s/it, selected_test_score=0.718]

{'epoch': 15, 'valid_accuracy': 0.72146, 'best_valid_epoch': 15, 'best_valid_score': 0.72146, 'test_accuracy': 0.71816, 'selected_test_score': 0.71816}


  5%|▌         | 16/300 [00:28<08:28,  1.79s/it, selected_test_score=0.72] 

{'epoch': 16, 'valid_accuracy': 0.72358, 'best_valid_epoch': 16, 'best_valid_score': 0.72358, 'test_accuracy': 0.7202, 'selected_test_score': 0.7202}


  6%|▌         | 17/300 [00:30<08:30,  1.81s/it, selected_test_score=0.721]

{'epoch': 17, 'valid_accuracy': 0.72392, 'best_valid_epoch': 17, 'best_valid_score': 0.72392, 'test_accuracy': 0.72112, 'selected_test_score': 0.72112}


  6%|▌         | 18/300 [00:32<08:27,  1.80s/it, selected_test_score=0.724]

{'epoch': 18, 'valid_accuracy': 0.72708, 'best_valid_epoch': 18, 'best_valid_score': 0.72708, 'test_accuracy': 0.7238, 'selected_test_score': 0.7238}


  6%|▋         | 19/300 [00:34<08:23,  1.79s/it, selected_test_score=0.724]

{'epoch': 19, 'valid_accuracy': 0.7284, 'best_valid_epoch': 19, 'best_valid_score': 0.7284, 'test_accuracy': 0.72432, 'selected_test_score': 0.72432}


  7%|▋         | 20/300 [00:35<08:21,  1.79s/it, selected_test_score=0.726]

{'epoch': 20, 'valid_accuracy': 0.73004, 'best_valid_epoch': 20, 'best_valid_score': 0.73004, 'test_accuracy': 0.72556, 'selected_test_score': 0.72556}


  7%|▋         | 21/300 [00:37<08:23,  1.80s/it, selected_test_score=0.729]

{'epoch': 21, 'valid_accuracy': 0.7309, 'best_valid_epoch': 21, 'best_valid_score': 0.7309, 'test_accuracy': 0.72926, 'selected_test_score': 0.72926}


  7%|▋         | 22/300 [00:39<08:19,  1.80s/it, selected_test_score=0.731]

{'epoch': 22, 'valid_accuracy': 0.73154, 'best_valid_epoch': 22, 'best_valid_score': 0.73154, 'test_accuracy': 0.73074, 'selected_test_score': 0.73074}


  7%|▋         | 22/300 [00:41<08:41,  1.88s/it, selected_test_score=0.731]


KeyboardInterrupt: 

In [28]:
model

DOFENClassifier(
  (condition_generation): ConditionGeneration(
    (phi_1): ModuleDict(
      (num): Sequential(
        (0): Reshape()
        (1): FastGroupConv1d(10, 10240, kernel_size=(1,), stride=(1,), groups=10)
        (2): Reshape()
      )
    )
  )
  (proj): Linear(in_features=1, out_features=16, bias=True)
  (rodt_construction): rODTConstruction()
  (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (attn): MultiheadAttention(
    (qw_proj): Linear(in_features=16, out_features=16, bias=True)
    (kw_proj): Linear(in_features=16, out_features=16, bias=True)
    (vw_proj): Linear(in_features=16, out_features=16, bias=True)
    (ow_proj): Linear(in_features=16, out_features=1, bias=True)
    (qE_proj): Linear(in_features=16, out_features=16, bias=True)
    (kE_proj): Linear(in_features=16, out_features=16, bias=True)
    (vE_proj): Linear(in_features=16, out_features=16, bias=True)
    (oE_proj): Linear(in_features=16, out_features=16, bias=True)
    (dropout): Dro

In [29]:
model.attn

MultiheadAttention(
  (qw_proj): Linear(in_features=16, out_features=16, bias=True)
  (kw_proj): Linear(in_features=16, out_features=16, bias=True)
  (vw_proj): Linear(in_features=16, out_features=16, bias=True)
  (ow_proj): Linear(in_features=16, out_features=1, bias=True)
  (qE_proj): Linear(in_features=16, out_features=16, bias=True)
  (kE_proj): Linear(in_features=16, out_features=16, bias=True)
  (vE_proj): Linear(in_features=16, out_features=16, bias=True)
  (oE_proj): Linear(in_features=16, out_features=16, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [30]:
model.cuda()(torch.tensor(train_X[:256]).cuda())

{'pred': tensor([[ 0.3960,  0.2966],
         [-0.3075,  0.4038],
         [ 0.2521, -0.0345],
         [-0.5294,  0.9460],
         [-0.6331,  0.6158],
         [-0.2786,  0.4342],
         [ 0.5308,  0.0683],
         [ 1.0910, -0.3817],
         [-0.5125,  0.6734],
         [ 0.2061,  0.2704],
         [-0.3352,  0.3928],
         [ 0.1079,  0.3883],
         [ 1.5631, -1.1562],
         [-0.7646,  0.6758],
         [ 2.2875, -1.7891],
         [-0.5398,  0.3726],
         [-0.1332,  0.5100],
         [ 0.7137, -0.0575],
         [-0.8001,  0.7812],
         [ 1.7682, -1.1256],
         [-0.9896,  0.5423],
         [ 1.9622, -1.7037],
         [-0.6393,  0.6541],
         [ 1.7233, -1.4720],
         [-0.8367,  0.9338],
         [ 0.2391,  0.3703],
         [-0.0406,  0.5239],
         [ 0.3454,  0.0383],
         [-0.8716,  0.7052],
         [-0.3691,  0.1509],
         [-0.9906,  0.8783],
         [ 0.3124,  0.1143],
         [-0.6862,  0.7564],
         [-0.5726,  0.5572],
      

In [17]:
activations = {}

def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

handle = model.get_submodule("attn.oE_proj").register_forward_hook(get_activation("oE_proj"))

In [18]:
model.cuda()(torch.tensor(train_X[:256]).cuda())

{'pred': tensor([[-0.1647,  0.6241],
         [-1.2888,  1.3635],
         [-0.7934,  0.9415],
         [-1.4424,  1.5319],
         [-1.0155,  1.0836],
         [-1.0390,  1.2782],
         [-0.2088,  0.7644],
         [-1.1447,  1.2297],
         [-1.1862,  1.3229],
         [-0.8219,  1.0915],
         [-0.5778,  0.8187],
         [-0.4195,  0.8935],
         [-0.2086,  0.7363],
         [-0.9380,  1.0292],
         [-0.3617,  0.8094],
         [-0.7664,  1.1210],
         [-0.0642,  0.6315],
         [-1.2997,  1.4078],
         [-1.4671,  1.5043],
         [ 0.2227,  0.3196],
         [-1.3820,  1.5328],
         [-0.4204,  0.8093],
         [-0.1210,  0.6234],
         [-0.6125,  0.9705],
         [-1.0423,  1.2774],
         [-0.5666,  0.9664],
         [ 0.0381,  0.4372],
         [-1.0460,  1.2743],
         [-0.6940,  0.9555],
         [-1.1396,  1.0931],
         [-0.1255,  0.5501],
         [-0.8774,  1.1176],
         [-0.2448,  0.7247],
         [-0.9834,  1.2011],
      

In [19]:
oE_proj_output = activations["oE_proj"]
print(oE_proj_output.shape)

torch.Size([40960, 4, 16])


In [20]:
oE_proj_output = oE_proj_output.reshape(256, -1, 4, 16)

In [21]:
oE_proj_output.shape

torch.Size([256, 160, 4, 16])

In [22]:
x = oE_proj_output[0].view(112, -1)

RuntimeError: shape '[112, -1]' is invalid for input of size 10240

In [23]:
x_ = x / ((x ** 2).sum(dim=1) ** 0.5)[:, None]
sim = x_ @ x_.T
sim

NameError: name 'x' is not defined

In [97]:
sim[0]

tensor([1.0000, 0.9864, 0.9966, 0.9950, 0.9987, 0.9916, 0.8117, 0.9820, 0.8812,
        0.9980, 0.8736, 0.9975, 0.9953, 0.7756, 0.8760, 0.9834, 0.8134, 0.9957,
        0.9969, 0.9995, 0.9883, 0.9851, 0.9917, 0.9916, 0.9787, 0.8838, 0.9921,
        0.9646, 0.9666, 0.9698, 0.9916, 0.9912, 0.9725, 0.8803, 0.8372, 0.9318,
        0.9742, 0.8234, 0.9980, 0.9429, 0.9838, 0.9895, 0.8261, 0.9684, 0.9800,
        0.7999, 0.8236, 0.7952, 0.7878, 0.9967, 0.9602, 0.9996, 0.9560, 0.8429,
        0.9823, 0.9934, 0.9196, 0.9720, 0.8384, 0.9974, 0.9914, 0.9196, 0.9915,
        0.9987, 0.9260, 0.9399, 0.9613, 0.9955, 0.9972, 0.9868, 0.8621, 0.9961,
        0.9988, 0.9552, 0.9926, 0.9950, 0.9754, 0.8788, 0.9974, 0.8764, 0.9973,
        0.8756, 0.7560, 0.9939, 0.9600, 0.9306, 0.9128, 0.9945, 0.9986, 0.9579,
        0.8050, 0.8734, 0.8130, 0.9218, 0.9924, 0.9981, 0.9805, 0.8486, 0.9330,
        0.9673, 0.8925, 0.9559, 0.9933, 0.9302, 0.8561, 0.8310, 0.8679, 0.9961,
        0.9639, 0.8331, 0.9967, 0.9446],