In [1]:
import os
import sys
import pymysql
import numpy as np
import pandas as pd
import re
import datetime
import category_encoders
import joblib

from Config import params_config, query_config, db_config
from Utils.bulk_insert import BulkInsert

import warnings
warnings.filterwarnings('ignore')

## fit_race_info_into_model.py

In [2]:
# Pull the SQL strings, the parameter dictionary and the DB connection settings
# out of the project-local Config package imported above.
queries = query_config.queries
parameters = params_config.parameters
db_params = db_config.db_params
# Open the MySQL connection used by the data-loading helpers below.
# NOTE(review): no matching con.close() is visible in this notebook.
con = pymysql.connect(**db_params)

In [3]:
def fetchall_and_make_list_by(query, con):
    """Run ``query`` on ``con`` and return every fetched row as a list.

    Args:
        query: SQL text handed to ``cursor.execute`` unchanged.
        con: DB-API style connection exposing ``cursor()``.

    Returns:
        A list of the fetched rows, or ``None`` when any step raised. The
        error is printed rather than re-raised (best-effort behaviour kept
        from the original implementation).
    """
    cursor = None
    try:
        cursor = con.cursor()
        cursor.execute(query)
        return list(cursor.fetchall())
    except Exception as e:
        print(e)
        return None
    finally:
        # Release the cursor on the failure path too; previously it was only
        # closed when execute/fetchall succeeded.
        if cursor is not None:
            cursor.close()

def get_training_race_data_frame(queries, parameters, con):
    """Load the training rows and wrap them in a DataFrame.

    The SQL comes from ``queries['TRAINING_DATA_FROM_MASTER_PRIOR_RESULT']``
    and the column labels from
    ``parameters['DATAFRAME_COL_NAMES']['training_race_data_cols']``.
    """
    rows = fetchall_and_make_list_by(queries['TRAINING_DATA_FROM_MASTER_PRIOR_RESULT'], con)
    column_names = parameters['DATAFRAME_COL_NAMES']['training_race_data_cols']
    return pd.DataFrame(rows, columns=column_names)

In [4]:
training_race_df =  get_training_race_data_frame(queries, parameters, con)

In [5]:
training_race_df.shape

(481289, 31)

In [6]:
training_race_df.head()

Unnamed: 0,race_id,race_timing,race_title,race_weather,race_condition,course_syokin_list,post_position,horse_number,href_to_the_horse,horse_sex_age_in_result,...,trainer_name_in_result,trainer_name_in_prior,href_to_the_owner,breeder_name,jockey_finish_first_second,horse_number_finish_first_second,stallion_finish_first_second,conbi_finish_first_second,zensou_info_list,arrival_order
0,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,1,1,https://www.keibalab.jp/db/horse/2006104859/,牝3,...,[美]田中清隆,,,,,,,,,4
1,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,1,2,https://www.keibalab.jp/db/horse/2006105710/,牝3,...,[美]萱野浩二,,,,,,,,,12
2,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,2,3,https://www.keibalab.jp/db/horse/2006103271/,牝3,...,[美]武藤善則,,,,,,,,,1
3,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,2,4,https://www.keibalab.jp/db/horse/2006102356/,牝3,...,[美]武市康男,,,,,,,,,6
4,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,3,5,https://www.keibalab.jp/db/horse/2006102229/,牝3,...,[美]藤原辰雄,,,,,,,,,5


## Class: Preprocessing

### Features from master

In [None]:
def _get_year_month_day_from_race_timing(x):
    date_str = re.match('([0-9]+)/([0-9]+)/([0-9]+)' , x).group()
    year = datetime.datetime.strptime(date_str, '%Y/%m/%d').year
    month = datetime.datetime.strptime(date_str, '%Y/%m/%d').month
    day = datetime.datetime.strptime(date_str, '%Y/%m/%d').day
    return pd.Series([year, month, day])

def _get_dow_from_race_timing(x):
    return re.search("土|日" , x).group() 

def _encode_dow(df):
    dow_mapping = {'土': 1, '日': 2}
    return df['dow'].map(dow_mapping)    

def _get_time_in_the_racecourse_from_race_timing(x):
    return int(re.split('([0-9]+)回([ぁ-んァ-ン 一-龥]+)([0-9]+)日目' , x)[1])

def _get_racecourse_from_race_timing(x):
    return re.split('([0-9]+)回([ぁ-んァ-ン 一-龥]+)([0-9]+)日目' , x)[2]

def _get_what_day_in_the_racecourse_from_race_timing(x):
    return int(re.split('([0-9]+)回([ぁ-んァ-ン 一-龥]+)([0-9]+)日目' , x)[3])

def _encode_race_course(df):
    race_course_mapping = {'函館': 1, '札幌': 2, '福島': 3, '東京': 4, '中山': 5, '新潟': 6, '中京': 7, '阪神': 8, '京都': 9, '小倉': 10}
    return df['race_course'].map(race_course_mapping)

In [None]:
def preprocess_race_timing(df):
    """Derive date, day-of-week and racecourse columns from ``df['race_timing']``.

    Adds, in this order: year/month/day, dow, dow_encoded, race_course,
    race_course_encoded, time_in_racecourse, what_day_in_racecourse.
    The frame is modified in place and also returned.
    """
    race_timing = df['race_timing']

    df[['year', 'month', 'day']] = race_timing.apply(_get_year_month_day_from_race_timing)

    df['dow'] = race_timing.apply(_get_dow_from_race_timing)
    df['dow_encoded'] = _encode_dow(df)

    df['race_course'] = race_timing.apply(_get_racecourse_from_race_timing)
    df['race_course_encoded'] = _encode_race_course(df)

    df['time_in_racecourse'] = race_timing.apply(_get_time_in_the_racecourse_from_race_timing)
    df['what_day_in_racecourse'] = race_timing.apply(_get_what_day_in_the_racecourse_from_race_timing)
    return df

In [None]:
training_race_df = preprocess_race_timing(training_race_df)

In [None]:
def encode_race_weather(df):
    """Map ``df['race_weather']`` to integer codes 1..7; unlisted labels become NaN."""
    # The 1-based position in this tuple is the code for the label.
    weather_in_code_order = ('晴', '曇', '小雨', '雨', '小雪', '雪', 'unknown')
    weather_codes = {label: code for code, label in enumerate(weather_in_code_order, start=1)}
    return df['race_weather'].map(weather_codes)

In [None]:
training_race_df['race_weather_encoded'] = encode_race_weather(training_race_df)

In [None]:
def encode_race_condition(df):
    """Map ``df['race_condition']`` to integer codes 1..5; unlisted labels become NaN."""
    # The 1-based position in this tuple is the code for the label.
    condition_in_code_order = ('良', '稍', '重', '不', 'unknown')
    condition_codes = {label: code for code, label in enumerate(condition_in_code_order, start=1)}
    return df['race_condition'].map(condition_codes)

In [None]:
training_race_df['race_condition_encoded'] = encode_race_condition(training_race_df)

In [None]:
def encode_fit_and_transform_href_to_the_horse(df):
    """Fit a category encoder on ``href_to_the_horse``, persist it, and append the encoding.

    The encoder class is chosen by
    ``parameters['HYPER_PARAMETERS']['CATEGORY_ENCODERS_FOR_HORSE']``
    ('TargetEncoder' or 'OrdinalEncoder'). The fitted encoder is dumped with
    joblib to ``parameters['FILE_NAME_OF_HORSE_CATEGORY_ENCODERS']``.

    Args:
        df: frame containing ``href_to_the_horse`` and the configured target column.

    Returns:
        ``df`` with one extra column, ``href_to_the_horse_encoded``.

    Raises:
        ValueError: if the configured encoder name is not supported
            (previously this surfaced as an unbound-variable error).

    NOTE(review): the encoder is fitted on and applied to the same rows; with
    TargetEncoder each row's encoding has seen that row's own target - confirm
    this is intended.
    """
    hyper_params = parameters['HYPER_PARAMETERS']
    encoder_name = hyper_params['CATEGORY_ENCODERS_FOR_HORSE']
    if encoder_name == 'TargetEncoder':
        encoder_class = category_encoders.TargetEncoder
    elif encoder_name == 'OrdinalEncoder':
        encoder_class = category_encoders.OrdinalEncoder
    else:
        raise ValueError(
            "Unsupported CATEGORY_ENCODERS_FOR_HORSE: {!r}".format(encoder_name))

    # handle_unknown is a constructor option; fit() accepts it only through
    # **kwargs and ignores it, so the configured value never took effect.
    ce = encoder_class(cols=['href_to_the_horse'],
                       handle_unknown=hyper_params['CATEGORY_ENCODERS_HANDLE_UNKNOWN'])
    ce.fit(df, df[parameters['DATAFRAME_COL_NAMES']['target_col']])
    joblib.dump(ce, parameters['FILE_NAME_OF_HORSE_CATEGORY_ENCODERS'])

    df_ce = ce.transform(df)
    df_ce = df_ce.rename(columns={'href_to_the_horse': 'href_to_the_horse_encoded'})
    return pd.concat([df, df_ce['href_to_the_horse_encoded']], axis=1)

In [None]:
# ce_loaded = joblib.load(parameters['FILE_NAME_OF_CATEGORY_ENCODERS'])
# ce_loaded

In [None]:
training_race_df = encode_fit_and_transform_href_to_the_horse(training_race_df)

### Features from prior or result

In [None]:
def _get_horse_age_and_sex_in_result(x):
    horse_sex = re.split('([ぁ-んァ-ン 一-龥]+)([0-9]+)' , x)[1]
    horse_age = int(re.split('([ぁ-んァ-ン 一-龥]+)([0-9]+)' , x)[2])
    return pd.Series([horse_sex, horse_age])

def  _encode_horse_sex(df_about_horse_sex):
    horse_sex_mapping = {'牡': 1, '牝': 2, 'セ': 3}
    return df_about_horse_sex.map(horse_sex_mapping)

def preprocess_horse_sex_age(df, target_cols_type):
    """Add ``horse_age`` and ``horse_sex_encoded`` (plus ``horse_sex`` for 'result').

    Args:
        df: frame to extend; modified in place and returned.
        target_cols_type: 'result' parses ``horse_sex_age_in_result``;
            'prior' reads ``horse_age_in_prior`` and ``horse_sex_in_prior``.
            Any other value leaves ``df`` untouched.

    Returns:
        The same ``df`` object.
    """
    if target_cols_type == 'result':
        df[['horse_sex', 'horse_age']] = df['horse_sex_age_in_result'].apply(_get_horse_age_and_sex_in_result)
        df['horse_sex_encoded'] = _encode_horse_sex(df['horse_sex'])
    elif target_cols_type == 'prior':
        # Read the age from the frame that was passed in; the earlier code
        # reached for the notebook-global training_race_df here, which is
        # wrong for any other frame (and index-misaligned for subsets).
        df['horse_age'] = pd.to_numeric(df["horse_age_in_prior"], errors='coerce')
        df['horse_sex_encoded'] = _encode_horse_sex(df['horse_sex_in_prior'])
    return df

In [None]:
training_race_df = preprocess_horse_sex_age(df=training_race_df, target_cols_type='result')

In [None]:
# training_race_df = training_race_df[training_race_df['horse_weight_in_result']!='計不(---)']

In [None]:
def _parse_horse_weight_increment(x):
    return int(x.replace('＋', '+').replace('－', '-').replace('---', '0'))

def _get_horse_weight_info_in_result(x):
    """Split a ``<weight>(<increment>)`` string into ``pd.Series([weight, increment])``.

    The text before the first ``(`` is parsed as the integer weight and the
    text inside the parentheses goes through ``_parse_horse_weight_increment``.
    When the weight part is not a number (e.g. ``'計不(---)'``) the weight is
    returned as NaN instead of aborting the whole ``apply`` with a ValueError.
    """
    # Raw string: same pattern as before, without the invalid '\(' escapes.
    parts = re.split(r'(\()(.*)(\))', x)
    try:
        horse_weight = int(parts[0])
    except ValueError:
        horse_weight = np.nan
    horse_weight_increment = _parse_horse_weight_increment(parts[2])
    return pd.Series([horse_weight, horse_weight_increment])

def _get_horse_weight_in_prior(x):
    try:
        return int(re.search("[0-9]+" , x).group())
    except TypeError:
        return np.nan

def _get_horse_weight_increment_in_prior(x):
    """Return the weight change found inside ``(...kg)`` in ``x``, or NaN for non-string input.

    The text between ``(`` and ``kg)`` is converted by
    ``_parse_horse_weight_increment``.
    """
    try:
        # Raw string: same pattern as before, without the invalid '\(' escapes.
        increment_text = re.split(r'(\()(.*)(kg\))', x)[2]
        return _parse_horse_weight_increment(increment_text)
    except TypeError:
        # re.split rejects None/NaN cells; treat them as missing.
        return np.nan

def preprocess_horse_weight_and_increment(df, target_cols_type):
    """Add ``horse_weight`` and ``horse_weight_increment`` columns to ``df``.

    'result' parses both values from ``horse_weight_in_result``; 'prior' reads
    ``horse_weight_in_prior`` and ``horse_weight_increment_in_prior``. Any other
    ``target_cols_type`` returns ``df`` unchanged. The frame is modified in place.
    """
    if target_cols_type == 'result':
        weight_and_increment = df['horse_weight_in_result'].apply(_get_horse_weight_info_in_result)
        df[['horse_weight', 'horse_weight_increment']] = weight_and_increment
        return df

    if target_cols_type == 'prior':
        df['horse_weight'] = df['horse_weight_in_prior'].apply(_get_horse_weight_in_prior)
        df['horse_weight_increment'] = df['horse_weight_increment_in_prior'].apply(
            _get_horse_weight_increment_in_prior)

    return df

In [None]:
training_race_df = preprocess_horse_weight_and_increment(df=training_race_df, target_cols_type='result')

In [None]:
def _get_and_encode_weight_loss_flg(x):
    try:
        weight_loss_flg = re.search('▲|△|☆' , x).group()
        weight_loss_encode = int(weight_loss_flg.replace('▲', '3').replace('△', '2').replace('☆', '1'))
    except AttributeError:
        weight_loss_encode = 0
    return weight_loss_encode

def _get_horse_impost_in_prior(x):
    try:
        return float(re.split('(▲|△|☆|.)(.*)(\()(.*)(\))(.*)' , x)[4])
    except TypeError:
        return np.nan

def _get_weight_loss_encode_in_prior(x):
    """Encode the leading character of ``x`` via ``_get_and_encode_weight_loss_flg``.

    Returns NaN when ``x`` is not a string (``re.split`` raises TypeError).
    """
    try:
        # Raw string: same pattern as before, without the invalid '\(' escapes.
        # Index 1 is the first capture group: a ▲/△/☆ mark or any other first character.
        leading_char = re.split(r'(▲|△|☆|.)(.*)(\()(.*)(\))(.*)', x)[1]
        return _get_and_encode_weight_loss_flg(leading_char)
    except TypeError:
        return np.nan

def preprocess_jockey_name(df, target_cols_type):
    """Add ``horse_impost`` and ``weight_loss_encode`` columns to ``df``.

    'result' copies ``horse_impost_in_result`` and encodes the mark found in
    ``jockey_name_in_result``; 'prior' derives both values from
    ``jockey_name_and_horse_impost_in_prior``. Any other ``target_cols_type``
    returns ``df`` unchanged. The frame is modified in place.
    """
    if target_cols_type == 'result':
        df['horse_impost'] = df['horse_impost_in_result']
        df['weight_loss_encode'] = df['jockey_name_in_result'].apply(_get_and_encode_weight_loss_flg)
        return df

    if target_cols_type == 'prior':
        prior_column = 'jockey_name_and_horse_impost_in_prior'
        df['horse_impost'] = df[prior_column].apply(_get_horse_impost_in_prior)
        df['weight_loss_encode'] = df[prior_column].apply(_get_weight_loss_encode_in_prior)

    return df

In [None]:
training_race_df = preprocess_jockey_name(df=training_race_df, target_cols_type='result')

In [None]:
def encode_fit_and_transform_href_to_the_jockey(df):
    """Fit a category encoder on ``href_to_the_jockey``, persist it, and append the encoding.

    The encoder class is chosen by
    ``parameters['HYPER_PARAMETERS']['CATEGORY_ENCODERS_FOR_JOCKEY']``
    ('TargetEncoder' or 'OrdinalEncoder'). The fitted encoder is dumped with
    joblib to ``parameters['FILE_NAME_OF_JOCKEY_CATEGORY_ENCODERS']``.

    Args:
        df: frame containing ``href_to_the_jockey`` and the configured target column.

    Returns:
        ``df`` with one extra column, ``href_to_the_jockey_encoded``.

    Raises:
        ValueError: if the configured encoder name is not supported
            (previously this surfaced as an unbound-variable error).

    NOTE(review): the encoder is fitted on and applied to the same rows; with
    TargetEncoder each row's encoding has seen that row's own target - confirm
    this is intended.
    """
    hyper_params = parameters['HYPER_PARAMETERS']
    encoder_name = hyper_params['CATEGORY_ENCODERS_FOR_JOCKEY']
    if encoder_name == 'TargetEncoder':
        encoder_class = category_encoders.TargetEncoder
    elif encoder_name == 'OrdinalEncoder':
        encoder_class = category_encoders.OrdinalEncoder
    else:
        raise ValueError(
            "Unsupported CATEGORY_ENCODERS_FOR_JOCKEY: {!r}".format(encoder_name))

    # handle_unknown is a constructor option; fit() accepts it only through
    # **kwargs and ignores it, so the configured value never took effect.
    ce = encoder_class(cols=['href_to_the_jockey'],
                       handle_unknown=hyper_params['CATEGORY_ENCODERS_HANDLE_UNKNOWN'])
    ce.fit(df, df[parameters['DATAFRAME_COL_NAMES']['target_col']])
    joblib.dump(ce, parameters['FILE_NAME_OF_JOCKEY_CATEGORY_ENCODERS'])

    df_ce = ce.transform(df)
    df_ce = df_ce.rename(columns={'href_to_the_jockey': 'href_to_the_jockey_encoded'})
    return pd.concat([df, df_ce['href_to_the_jockey_encoded']], axis=1)

In [None]:
# ce_loaded = joblib.load(parameters['FILE_NAME_OF_JOCKEY_CATEGORY_ENCODERS'])
# ce_loaded

In [None]:
training_race_df = encode_fit_and_transform_href_to_the_jockey(training_race_df)

In [None]:
def _get_trainer_belonging_in_result(x):
    return re.split('\[(.*)\]' , x)[1]

def _get_trainer_belonging_in_prior(x):
    try:
        return re.split('(.*)(・)(.*)' , x)[1]
    except TypeError:
        return np.nan

def _encode_trainer_belonging(df):
    trainer_belonging_mapping = {'美': 1, '栗': 2, '招': 3}
    return df['trainer_belonging'].map(trainer_belonging_mapping)

def preprocess_trainer_name(df, target_cols_type):
    """Add ``trainer_belonging`` and ``trainer_belonging_encoded`` columns to ``df``.

    'result' parses ``trainer_name_in_result``; 'prior' parses
    ``trainer_name_in_prior``. Any other ``target_cols_type`` returns ``df``
    unchanged. The frame is modified in place.
    """
    if target_cols_type == 'result':
        source_column, extract = 'trainer_name_in_result', _get_trainer_belonging_in_result
    elif target_cols_type == 'prior':
        source_column, extract = 'trainer_name_in_prior', _get_trainer_belonging_in_prior
    else:
        return df

    # Both variants share the same two steps once the source is chosen.
    df['trainer_belonging'] = df[source_column].apply(extract)
    df['trainer_belonging_encoded'] = _encode_trainer_belonging(df)
    return df

In [None]:
training_race_df = preprocess_trainer_name(df=training_race_df, target_cols_type='result')

In [None]:
# def encode_fit_and_transform_href_to_the_trainer(df):
#     if parameters['HYPER_PARAMETERS']['CATEGORY_ENCODERS_FOR_TRAINER']=='TargetEncoder':
#         ce = category_encoders.TargetEncoder(cols=['href_to_the_trainer'])
#     elif parameters['HYPER_PARAMETERS']['CATEGORY_ENCODERS_FOR_TRAINER']=='OrdinalEncoder':
#         ce = category_encoders.OrdinalEncoder(cols=['href_to_the_trainer'])
        
#     ce.fit(df, 
#            df[parameters['DATAFRAME_COL_NAMES']['target_col']],
#            handle_unknown=parameters['HYPER_PARAMETERS']['CATEGORY_ENCODERS_HANDLE_UNKNOWN'])
#     joblib.dump(ce, parameters['FILE_NAME_OF_TRAINER_CATEGORY_ENCODERS'])
    
#     df_ce = ce.transform(df)
#     df_ce = df_ce.rename(columns={'href_to_the_trainer': 'href_to_the_trainer_encoded'})
#     return pd.concat([df, df_ce['href_to_the_trainer_encoded']], axis=1)

In [None]:
# training_race_df = encode_fit_and_transform_href_to_the_trainer(training_race_df)

In [None]:
training_race_df.head()

In [None]:
training_race_df.tail()

## Check whether the Preprocessing class works

In [7]:
from Model.Preprocessing import Preprocessing

In [8]:
pp = Preprocessing(parameters)

In [9]:
def preprocess_result_data_based_training_race_df(df, pp):
    """Run the result-based preprocessing steps of ``pp`` on ``df`` in a fixed order.

    Each step is called as ``pp.<step>(df=df, ...)`` and its return value is
    fed to the next one; the output of the last step is returned.
    """
    from_result = {'target_cols_type': 'result'}
    # (method name on pp, extra keyword arguments) in execution order.
    steps = (
        ('preprocess_race_timing', {}),
        ('encode_race_weather', {}),
        ('encode_race_condition', {}),
        ('encode_fit_and_transform_href_to_the_horse', {}),
        ('preprocess_horse_sex_age', from_result),
        ('preprocess_horse_weight_and_increment', from_result),
        ('preprocess_jockey_name', from_result),
        ('encode_fit_and_transform_href_to_the_jockey', {}),
        ('preprocess_trainer_name', from_result),
        ('preprocess_arrival_order', {}),
    )
    for step_name, extra_kwargs in steps:
        df = getattr(pp, step_name)(df=df, **extra_kwargs)
    return df

In [10]:
training_race_df_preprocessed = preprocess_result_data_based_training_race_df(training_race_df, pp)

In [11]:
training_race_df_preprocessed.head()

Unnamed: 0,race_id,race_timing,race_title,race_weather,race_condition,course_syokin_list,post_position,horse_number,href_to_the_horse,horse_sex_age_in_result,...,horse_age,horse_sex_encoded,horse_weight,horse_weight_increment,horse_impost,weight_loss_encode,href_to_the_jockey_encoded,trainer_belonging,trainer_belonging_encoded,arrival_order_category
0,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,1,1,https://www.keibalab.jp/db/horse/2006104859/,牝3,...,3,2,436,-2,54.0,0,7.571491,美,1,4
1,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,1,2,https://www.keibalab.jp/db/horse/2006105710/,牝3,...,3,2,472,0,54.0,0,9.916139,美,1,4
2,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,2,3,https://www.keibalab.jp/db/horse/2006103271/,牝3,...,3,2,398,0,52.0,2,9.119874,美,1,1
3,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,2,4,https://www.keibalab.jp/db/horse/2006102356/,牝3,...,3,2,464,6,54.0,0,9.5212,美,1,4
4,200906070501,2009/6/7(日) 3回東京6日目,サラ系3歳未勝利,晴,重,サラ系3歳未勝利 牝 [指] 馬齢 ダ1400m 16頭 10:00発走 本賞金 500万 ...,3,5,https://www.keibalab.jp/db/horse/2006102229/,牝3,...,3,2,472,0,51.0,3,7.593424,美,1,4


## Modeling Process

In [86]:
# parameters
parameters['CRITERIA_FOR_SPLIT_TRAINING_DATA'] ={'year': 2019, 'month': 5}

def make_dataset_to_model_fit(df, params=None):
    """Split ``df`` chronologically into training and validation features/targets.

    Rows strictly before the cut-off (year, month) go to training; rows at or
    after it go to validation. The earlier condition compared year and month
    independently (``year < Y or month < M``), so months before ``M`` in years
    after ``Y`` were put into training even though they lie after the
    validation start.

    Args:
        df: frame with ``year`` and ``month`` columns plus the configured
            feature and target columns.
        params: optional parameter dictionary; defaults to the notebook-level
            ``parameters``. Must provide ``CRITERIA_FOR_SPLIT_TRAINING_DATA``
            (``year``, ``month``) and ``DATAFRAME_COL_NAMES``
            (``feature_cols_part1``, ``target_col``).

    Returns:
        ``(x_train_df, y_train_df, x_valid_df, y_valid_df)``.
    """
    if params is None:
        params = parameters
    split = params['CRITERIA_FOR_SPLIT_TRAINING_DATA']
    feature_cols = params['DATAFRAME_COL_NAMES']['feature_cols_part1']
    target_col = params['DATAFRAME_COL_NAMES']['target_col']

    same_year = df['year'] == split['year']
    # Both masks are written out (instead of negating one) so rows with a
    # missing year/month stay out of both splits, as before.
    is_train = (df['year'] < split['year']) | (same_year & (df['month'] < split['month']))
    is_valid = (df['year'] > split['year']) | (same_year & (df['month'] >= split['month']))

    train_df = df[is_train]
    validation_df = df[is_valid]

    x_train_df = train_df[feature_cols]
    y_train_df = train_df[target_col]
    x_valid_df = validation_df[feature_cols]
    y_valid_df = validation_df[target_col]

    return x_train_df, y_train_df, x_valid_df, y_valid_df

In [87]:
x_train_df, y_train_df, x_valid_df, y_valid_df = make_dataset_to_model_fit(df=training_race_df_preprocessed)

In [88]:
print(x_train_df.shape)
print(y_train_df.shape)

print(x_valid_df.shape)
print(y_valid_df.shape)

(469765, 22)
(469765,)
(11524, 22)
(11524,)


In [89]:
y_train_df.groupby(y_train_df.values).count()

1     32978
2     32970
3     32964
4    370853
Name: arrival_order_category, dtype: int64

In [90]:
y_valid_df.groupby(y_valid_df.values).count()

1     861
2     861
3     862
4    8940
Name: arrival_order_category, dtype: int64

In [112]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

from sklearn.metrics import classification_report
from scipy.stats import randint as sp_randint

In [114]:
# Search settings for the random-forest classifier.
# 'GS_PARAMS' is the explicit grid used with GridSearchCV; 'RS_PARAMS' holds
# the sampling distributions intended for RandomizedSearchCV.
parameters["HYPER_PARAMETERS"]['RF_CLF'] = {
    # The comma after this entry was missing, which made the whole cell a SyntaxError.
    'CV_WAYS': 'GridSearchCV',  # 'GridSearchCV', 'RandomizedSearchCV'
    'GS_PARAMS': {
        'n_estimators': [10, 50, 100],
        'max_depth': [5, 10, 20],
        'max_features': ['sqrt', 'log2', None],
        'class_weight': ['balanced', None],
    },
    'RS_PARAMS': {
        'n_estimators': sp_randint(10, 100),
        'max_depth': sp_randint(5, 20),
        'max_features': ['sqrt', 'log2', None],
        'class_weight': ['balanced', None],
    },
}

In [96]:
# Baseline: a random forest with library-default hyper-parameters (only the
# seed is fixed), fitted on the training split.
rf_clf = RandomForestClassifier(random_state=0)
rf_clf.fit(x_train_df, y_train_df)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=None, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=10,
                       n_jobs=None, oob_score=False, random_state=0, verbose=0,
                       warm_start=False)

In [105]:
y_valid_pred = rf_clf.predict(x_valid_df)
# pd.Series(y_valid_pred).groupby(pd.Series(y_valid_pred).values).count()

In [109]:
print(classification_report(y_valid_df, y_valid_pred))

              precision    recall  f1-score   support

           1       0.30      0.24      0.27       861
           2       0.20      0.12      0.15       861
           3       0.18      0.05      0.08       862
           4       0.83      0.94      0.88      8940

    accuracy                           0.76     11524
   macro avg       0.38      0.34      0.34     11524
weighted avg       0.69      0.76      0.72     11524



In [None]:
# rscv = RandomizedSearchCV(estimator=RandomForestClassifier(random_state=0),
#                                     param_distributions=parameters["HYPER_PARAMETERS"]['RF_CLF']['GS_PARAMS'],
#                                     n_iter=54,
#                                     scoring="f1_weighted",
#                                     cv=3,
#                                     verbose=1,
#                                     n_jobs=-1,          
#                                     random_state=1)

In [115]:
# Exhaustive search over the 'GS_PARAMS' grid with 3-fold cross-validation,
# scored by weighted F1 and run on all available cores (n_jobs=-1).
gscv = GridSearchCV(estimator=RandomForestClassifier(random_state=0),
                    param_grid=parameters["HYPER_PARAMETERS"]['RF_CLF']['GS_PARAMS'],
                    scoring="f1_weighted",
                    cv=3,
                    verbose=1,
                    n_jobs=-1) 

In [None]:
gscv.fit(x_train_df, y_train_df)

In [None]:
# Show the estimator selected by the grid search (the stray second dot was a SyntaxError).
gscv.best_estimator_

In [None]:
gscv.best_params_

In [None]:
# Re-fit a forest with the best hyper-parameters found by the grid search.
# The subscripts were left empty (a SyntaxError); each one now names the key
# of the matching entry in the 'GS_PARAMS' grid.
rf_clf = RandomForestClassifier(random_state=0,
                                n_estimators=gscv.best_params_['n_estimators'],
                                max_depth=gscv.best_params_['max_depth'],
                                max_features=gscv.best_params_['max_features'],
                                class_weight=gscv.best_params_['class_weight'])
rf_clf.fit(x_train_df, y_train_df)

### Try Learning to Rank

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Net(nn.Module):
    """Scoring network: ``D`` inputs -> 10 sigmoid hidden units -> 1 output per row."""

    def __init__(self, D):
        super().__init__()
        # Layer attribute names are kept because they define the state_dict keys.
        self.l1 = nn.Linear(D, 10)
        self.l2 = nn.Linear(10, 1)

    def forward(self, x):
        """Return the raw (unbounded) score for each row of ``x``."""
        hidden = self.l1(x).sigmoid()
        return self.l2(hidden)

In [None]:
def listnet_loss(y_i, z_i):
    """Cross-entropy between the softmax distributions of targets and predictions.

    Args:
        y_i: target scores for one list, shape (n_i, 1).
        z_i: predicted scores for the same list, shape (n_i, 1).

    Returns:
        Scalar tensor ``-sum(softmax(y_i) * log_softmax(z_i))`` over dim 0.
    """
    P_y_i = F.softmax(y_i, dim=0)
    # log_softmax is computed in one stable step; log(softmax(z)) turns into
    # -inf (and the loss into inf) as soon as a softmax entry underflows to 0.
    log_P_z_i = F.log_softmax(z_i, dim=0)
    return - torch.sum(P_y_i * log_P_z_i)

def make_dataset(N_train, N_valid, D):
    """Create a random synthetic ranking dataset.

    Features are standard-normal tensors (created with ``requires_grad=True``);
    each row is scored with one shared random weight vector and the score is
    binned at [-2, -1, 0, 1] into a relevance grade.

    Returns:
        ``(X_train, X_valid, ys_train_rel, ys_valid_rel)``.
    """
    relevance_bins = [-2, -1, 0, 1]  # 4 edges -> 5 relevance grades

    def _relevance(features):
        # Work on a detached copy so numpy can read the scores.
        scores = torch.mm(features, true_weights).clone().detach().numpy()
        return torch.Tensor(np.digitize(scores, bins=relevance_bins))

    # Draw order (weights, train, valid) is kept so seeded runs give the same data.
    true_weights = torch.randn(D, 1)
    X_train = torch.randn(N_train, D, requires_grad=True)
    X_valid = torch.randn(N_valid, D, requires_grad=True)

    return X_train, X_valid, _relevance(X_train), _relevance(X_valid)


def swapped_pairs(ys_pred, ys_target):
    """Count item pairs whose predicted order contradicts their target order.

    Pairs with equal targets, or with equal predictions, are not counted.
    """
    n_items = ys_target.shape[0]
    n_swapped = 0
    for first in range(n_items - 1):
        for second in range(first + 1, n_items):
            if ys_target[first] < ys_target[second]:
                discordant = ys_pred[first] > ys_pred[second]
            elif ys_target[first] > ys_target[second]:
                discordant = ys_pred[first] < ys_pred[second]
            else:
                discordant = False
            if discordant:
                n_swapped += 1
    return n_swapped


def ndcg(ys_true, ys_pred):
    """Return DCG of ``ys_true`` ordered by ``ys_pred`` divided by its ideal DCG.

    The gain of an item with relevance ``l`` at 1-based rank ``i`` is
    ``(2 ** l - 1) / log2(1 + i)``.
    """
    def _dcg(relevance, ordering_scores):
        # Rank items by descending score, then accumulate the discounted gains in rank order.
        _, order = torch.sort(ordering_scores, descending=True, dim=0)
        ranked_relevance = relevance[order]
        return sum(
            (2 ** rel - 1) / np.log2(1 + rank)
            for rank, rel in enumerate(ranked_relevance, 1)
        )

    achieved = _dcg(ys_true, ys_pred)
    best_possible = _dcg(ys_true, ys_true)
    return achieved / best_possible

In [None]:
# Size of the synthetic ranking problem: row counts and feature dimension.
N_train = 500
N_valid = 100
D = 50
# NOTE(review): `epochs` is not read by any cell visible in this notebook.
epochs = 10
# Number of rows taken per mini-batch slice in the cells below.
batch_size = 16

X_train, X_valid, ys_train, ys_valid = make_dataset(N_train, N_valid, D)

In [None]:
net = Net(D)
opt = optim.Adam(net.parameters())

In [None]:
epoch = 0

In [None]:
# Shuffle the training rows; the same permutation is applied to features and
# labels so each row keeps its own label.
idx = torch.randperm(N_train)

X_train = X_train[idx]
ys_train = ys_train[idx]

# Start offset of the next mini-batch slice.
cur_batch = 0

In [None]:
it = 0

In [None]:
batch_X = X_train[cur_batch: cur_batch + batch_size]
batch_ys = ys_train[cur_batch: cur_batch + batch_size]
cur_batch += batch_size

In [None]:
batch_X.shape

In [None]:
opt.zero_grad()

In [None]:
opt

In [None]:
batch_pred = net(batch_X)
batch_pred

In [None]:
batch_ys

In [None]:
batch_loss = listnet_loss(batch_ys, batch_pred)
batch_loss

In [None]:
batch_loss.backward(retain_graph=True)

In [None]:
opt.step()

In [None]:
opt