In [None]:
class ShotGenEncoder(nn.Module):
    """Embed per-shot features and encode them with one "local" and one "global" layer.

    Every feature is embedded, the embeddings are summed into an area stream
    (seeded with the landing-coordinate embedding) and a shot stream (seeded
    with the ``'type'`` embedding), and the two streams are passed through:

    * ``local_layer`` on the full sequence, and
    * ``global_layer`` on the even-indexed and on the odd-indexed positions of
      dim 1 separately (the original code labels this split "split player").

    Config keys read here: ``'<feature>_num'`` for every categorical feature
    (``'shot_num'`` for the ``'type'`` feature), ``'var_dim'``, ``'area_dim'``,
    ``'encode_dim'``, ``'shot_dim'``, ``'encode_length'``, ``'max_ball_round'``.

    NOTE(review): the embeddings are added elementwise, so ``var_dim`` and
    ``area_dim`` must be broadcast-compatible (presumably equal), and they
    presumably have to match ``shot_dim`` / ``encode_dim`` as well -- confirm
    against the configuration used by the caller.
    """

    # Real-valued inputs; every other feature is looked up in an embedding table.
    _COORDINATE_KEYS = ('landing_x', 'landing_y')
    _SCALAR_KEYS = ('time_diff', 'shot_angle', 'distance')

    def __init__(self, config, feature_name):
        """Build the feature embeddings, the positional encoding and both encoder layers.

        Args:
            config: mapping providing the keys listed in the class docstring.
            feature_name: sequence of feature names; its last three entries
                are ignored (no embedding is created for them).
        """
        super().__init__()
        # Kept as a plain dict so existing ``self.feature_embedding[key]``
        # lookups keep working; ``_register_feature`` additionally registers
        # every layer as a submodule.
        self.feature_embedding = dict()
        feature_name = feature_name[:-3]

        continuous_keys = self._COORDINATE_KEYS + self._SCALAR_KEYS
        for key in feature_name:
            if key in continuous_keys:
                continue
            # The vocabulary size of the 'type' feature is stored as 'shot_num'.
            num = 'shot_num' if key == 'type' else key + '_num'
            self._register_feature(key, Embedding(config[num], config['var_dim']))

        # The two landing coordinates are projected jointly, the remaining
        # real-valued features one scalar at a time.
        self._register_feature('area', nn.Linear(2, config['area_dim']))
        for key in self._SCALAR_KEYS:
            self._register_feature(key, nn.Linear(1, config['area_dim']))

        n_heads = 2
        d_k = config['encode_dim']
        d_v = config['encode_dim']
        d_model = config['encode_dim']
        d_inner = config['encode_dim'] * 2
        dropout = 0.1
        self.d_model = d_model

        self.position_embedding = PositionalEncoding(config['shot_dim'], config['encode_length'],
                                                     n_position=config['max_ball_round'])
        self.dropout = nn.Dropout(p=dropout)

        self.global_layer = EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)
        self.local_layer = EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)

    def _register_feature(self, key, module):
        """Store ``module`` under ``key`` and register it as a submodule.

        A bare dict hides its values from ``nn.Module``: they would be missing
        from ``parameters()``, ``state_dict()`` and ``.to(device)``.
        ``nn.ModuleDict`` cannot be used directly because the key ``'type'``
        collides with ``nn.Module.type``, hence the prefixed attribute name.
        """
        self.feature_embedding[key] = module
        # Submodule names must not contain '.'.
        self.add_module('feature_embedding_' + key.replace('.', '_'), module)

    def forward(self, input_dict, src_mask=None, return_attns=False):
        """Encode one batch of shot features.

        Args:
            input_dict: mapping from feature name to tensor. Must contain
                ``'landing_x'``, ``'landing_y'``, ``'time_diff'``,
                ``'shot_angle'``, ``'distance'`` and ``'type'``. Keys for
                which no embedding was built in ``__init__`` are ignored.
                Assumes every tensor is (batch, seq_len) -- TODO confirm.
            src_mask: passed unchanged as ``slf_attn_mask`` to all three
                encoder-layer calls.
            return_attns: when true, also return the attention outputs.

        Returns:
            ``(local, global_A, global_B)``, followed by the list
            ``[local_attn, global_A_attn, global_B_attn]`` when
            ``return_attns`` is true.
        """
        area = torch.cat((input_dict['landing_x'].unsqueeze(-1), input_dict['landing_y'].unsqueeze(-1)), dim=-1).float()

        embedded_dict = dict()
        embedded_dict['area'] = F.relu(self.feature_embedding['area'](area))
        for key in self._SCALAR_KEYS:
            # nn.Linear(1, ...) needs a trailing feature axis of size 1, added
            # the same way as for the landing coordinates above.
            scalar = input_dict[key].unsqueeze(-1).float()
            embedded_dict[key] = F.relu(self.feature_embedding[key](scalar))

        handled = self._COORDINATE_KEYS + self._SCALAR_KEYS + ('area',)
        for key in input_dict:
            # Skip what was embedded above and anything __init__ created no
            # embedding for (it drops the last three feature names).
            if key in handled or key not in self.feature_embedding:
                continue
            embedded_dict[key] = self.feature_embedding[key](input_dict[key].to(torch.int64))

        # Both streams receive every embedding except each other's seed.
        h_a = embedded_dict['area']
        h_s = embedded_dict['type']
        for key, embedded in embedded_dict.items():
            if key == 'area' or key == 'type':
                continue
            h_a = h_a + embedded
            h_s = h_s + embedded

        # Split into even / odd positions along dim 1 ("split player").
        h_a_A = h_a[:, ::2]
        h_a_B = h_a[:, 1::2]
        h_s_A = h_s[:, ::2]
        h_s_B = h_s[:, 1::2]

        # local
        encode_output_area = self.dropout(self.position_embedding(h_a, mode='encode'))
        encode_output_shot = self.dropout(self.position_embedding(h_s, mode='encode'))

        # global
        encode_output_area_A = self.dropout(self.position_embedding(h_a_A, mode='encode'))
        encode_output_area_B = self.dropout(self.position_embedding(h_a_B, mode='encode'))
        encode_output_shot_A = self.dropout(self.position_embedding(h_s_A, mode='encode'))
        encode_output_shot_B = self.dropout(self.position_embedding(h_s_B, mode='encode'))

        encode_global_A, enc_slf_attn_A = self.global_layer(encode_output_area_A, encode_output_shot_A,
                                                            slf_attn_mask=src_mask)
        encode_global_B, enc_slf_attn_B = self.global_layer(encode_output_area_B, encode_output_shot_B,
                                                            slf_attn_mask=src_mask)
        encode_local_output, enc_slf_attn = self.local_layer(encode_output_area, encode_output_shot,
                                                             slf_attn_mask=src_mask)
        if return_attns:
            # Same order as the encodings returned alongside.
            enc_slf_attn_list = [enc_slf_attn, enc_slf_attn_A, enc_slf_attn_B]
            return encode_local_output, encode_global_A, encode_global_B, enc_slf_attn_list
        return encode_local_output, encode_global_A, encode_global_B


In [None]:
# Names of the selected features.
# NOTE(review): ShotGenEncoder.__init__ slices its ``feature_name`` argument
# with ``[:-3]``. If this list is what gets passed there, the trailing
# 'rally_id', 'set' and 'ball_round' entries are the ones that receive no
# embedding, so the order of the tail matters -- confirm against the caller.
self.feature_selected = ['type', 'landing_x', 'landing_y', 'player',
                                 'score_diff','time_diff',
                                 'aroundhead', 'backhand', 'landing_height',
                                 'shot_angle','distance', 
                                 'x_distance','y_distance',
                                 'rally_id',
                                 'set', 'ball_round']