In [89]:
#!g1.1
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import scipy
import plotly.express as px

import xgboost as xgb
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

from catboost import CatBoostClassifier, Pool


In [90]:
#!g1.1
test = pd.read_csv('1sentencenewtest.csv', index_col=0)

In [91]:
#!g1.1
embeddings_semantic = pd.read_csv('sen_emb.csv', index_col=0)

In [92]:
#!g1.1
# Dictionary mapping
dict_label = {0: '-', 1: '?', 2: '+'}

# map the values in the column using the dictionary
embeddings_semantic['target'] = embeddings_semantic['target'].map(dict_label)



In [93]:
#!g1.1
X_emb_semantic = embeddings_semantic.drop('target',axis=1)
y_emb_semantic = embeddings_semantic['target']

In [95]:
#!g1.1
embeddings_categories = pd.read_csv('ok_categ_emb_train_final_ok.csv', index_col=0)

In [96]:
#!g1.1
embeddings_categories

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,985,986,987,988,989,990,991,992,993,994,995,996,997,998,999,1000,1001,1002,1003,1004,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023,target
0,-0.197150,-0.492733,0.242752,0.276384,-0.341769,0.649663,0.436415,1.028297,0.759479,-0.037235,2.599360,-0.774528,-0.070861,-0.533864,0.208983,0.597830,0.014157,-0.097436,-0.319961,-0.115878,-0.649831,1.260697,0.245076,-0.119573,0.026702,0.426941,1.193080,-0.967740,-1.475101,-0.509235,0.723489,-0.086194,-0.839914,-0.446740,-1.913244,-0.289067,-0.963961,0.814593,0.926772,-0.434793,...,1.195466,-0.361464,0.633875,-0.866637,0.799471,1.045184,0.368548,-1.424376,0.654992,-0.024584,1.231598,0.900147,2.120381,-0.905322,-0.063521,0.968733,-2.054625,0.708550,0.091177,0.217109,-0.556910,-0.577859,1.293057,-0.120326,1.450834,-0.973732,0.048671,0.688070,-0.337478,0.290804,0.565368,0.123341,0.738762,0.165799,0.803596,-0.296038,0.106892,1.662816,-0.117572,0
1,-0.296503,-1.216333,1.987220,0.569431,-0.046948,0.461104,0.298227,0.633410,1.130803,0.073801,2.258140,0.071064,0.367049,-0.502735,0.012630,0.561093,-0.205589,-0.766349,-0.945479,-0.626539,-0.664176,0.343784,1.277935,-0.658621,-1.359564,0.801554,-0.688164,-0.215484,0.476578,0.081050,0.775126,0.192444,-1.824652,-0.679714,-0.570320,-1.662270,-1.181905,1.233421,1.264541,-1.427104,...,-0.538562,-0.082593,0.167374,-1.204939,-0.129600,1.481471,0.391387,-0.551982,0.046777,-0.580550,0.900929,0.273328,1.811638,-0.942136,-0.443595,1.398365,-1.689520,0.293095,-1.952150,0.854172,-0.221443,-0.436292,1.131780,-0.476253,1.847384,-2.223216,1.268524,-0.517329,-0.754903,0.507452,0.719096,0.539857,0.753479,0.624003,-0.083083,1.466187,-0.607539,0.729190,0.067639,0
2,-0.455742,-0.755382,2.422975,0.845510,-0.805267,0.336262,0.373045,0.768055,0.431779,-0.641861,2.248143,-0.031287,0.548020,-0.028080,-0.328902,0.612601,-0.423368,-0.899131,-0.614806,-0.572159,-0.720578,1.132529,1.642390,-0.507976,-1.161364,0.892079,-0.290592,-0.737740,0.486198,-0.296975,0.405166,0.330061,-1.975672,-0.070250,-0.813980,-1.344733,-0.810822,0.653140,1.400097,-1.611629,...,-0.081948,-0.313064,-0.171829,-0.742366,-0.079756,1.152810,0.101971,-1.190078,0.381606,-0.380522,0.620796,0.581781,0.753662,-0.568913,0.172460,1.031953,-0.857331,0.494861,-1.682483,0.471346,0.044041,0.719867,1.243587,-0.350345,2.076875,-1.979518,1.170952,-0.413533,-0.296651,0.390602,0.515355,0.262573,1.408943,1.750823,-0.072048,0.656361,0.093258,0.314159,0.106662,0
3,-0.668212,0.672908,0.571884,0.840167,0.428761,0.543415,0.067167,0.790246,0.060164,0.736906,2.479573,-1.295207,0.068386,-1.071007,0.056986,0.670980,0.279546,0.372412,-0.602815,-0.592236,-0.540556,1.490815,0.598613,-0.684298,-0.356956,0.141763,0.805051,0.081773,0.259099,-0.096878,0.001114,1.728352,-0.884207,0.070567,-1.334123,-0.749713,0.077281,0.042504,1.068010,-0.891244,...,0.676975,-0.274169,0.566250,-1.110650,-0.083266,0.706527,0.165677,-1.319869,1.014560,-0.209731,0.551011,-0.144565,1.465563,-0.576977,0.228745,1.019863,-1.134535,0.478835,0.716006,0.311336,-0.556460,0.227912,0.680366,0.895348,1.790665,-0.560592,-0.008492,0.162080,-0.339853,-0.804344,0.738334,0.374959,0.579484,0.899126,1.006058,-0.462675,-0.096167,0.956857,-0.478446,0
4,-0.168889,-0.092882,0.847053,0.674367,-0.025841,0.611089,0.480568,0.818037,0.602498,-0.522791,2.287829,-0.743721,-0.721283,-0.277202,0.102453,0.407518,-0.449211,-0.081354,-0.730308,0.329423,-0.642107,1.147822,1.060459,-0.716144,-0.644513,0.393206,0.376873,-0.232740,-0.275478,-0.312748,0.693516,0.820357,-0.954473,-0.142005,-1.461858,-1.592135,-0.623946,0.624952,1.322435,-0.741527,...,0.751057,-0.576247,0.543056,-0.370677,-0.147288,1.837498,0.938366,-1.284908,0.702035,0.287762,0.860459,0.146761,2.341014,-0.323227,0.322070,1.102908,-1.195311,0.644721,-0.644211,0.023760,-0.907322,0.182716,1.237586,0.429132,1.620715,-1.317639,0.468526,-0.116301,-0.597784,0.279748,0.737442,0.270749,1.139555,0.998746,0.201250,0.677216,-0.197071,0.624862,0.009942,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13430,-0.353714,-0.223488,1.686368,0.891664,0.786180,0.875882,0.600786,0.983942,-0.095056,-0.114852,2.705800,-0.327013,-0.566767,-0.330950,-0.141602,0.562904,-0.305978,-0.347259,0.299850,-0.196189,-0.407356,0.776028,1.855242,-0.779697,-0.807154,0.296529,0.037179,-0.059086,0.044585,-0.182046,0.917416,1.067814,-1.360695,-0.516148,-1.097545,-1.859179,-0.545107,0.387442,1.808913,-2.349185,...,0.342987,-0.233487,0.602466,-1.133934,-0.272395,1.104759,0.607770,-1.175140,0.412642,-0.511790,0.379046,-0.222514,2.560421,-0.553303,0.508328,0.718235,-1.979393,0.332877,-0.933917,0.060639,-1.052874,0.498357,1.640089,0.219139,0.641003,-1.628707,0.191459,-0.447912,-0.082249,0.105909,0.968914,-0.631415,1.352886,0.454169,0.365777,1.010790,0.304535,0.143226,-0.091149,0
13431,0.351023,-0.203113,1.685679,0.039722,0.849022,0.498020,-0.671521,1.262892,0.130761,-0.201025,1.931395,0.211061,0.120855,-1.265413,-1.498230,0.755736,-1.046076,0.297588,1.011198,0.293241,0.381936,0.988524,0.465505,-0.579563,-0.740599,1.323760,-0.267593,0.277886,-0.848724,-0.980781,0.024949,1.055314,0.123415,-2.044379,0.114929,-1.322839,0.030363,-0.297329,0.064890,-1.557853,...,0.948259,0.858692,0.440549,-0.677409,-0.159740,1.774594,0.845124,-1.450790,-1.054888,-0.273919,-0.195623,0.838616,2.194146,-0.599141,1.148862,0.604915,-0.088258,-0.644520,-1.003095,-0.152390,-0.472729,0.807565,1.463213,2.156046,-0.341006,-1.159779,-0.733229,-0.620816,-0.076057,0.474849,-1.820013,-0.203112,0.910340,0.218210,0.601231,0.396834,-0.455897,0.190515,-0.119024,0
13432,0.217026,0.109330,0.383601,0.154830,0.488522,0.736103,1.294642,0.956266,0.526242,0.019604,2.553675,-0.760915,-0.429451,-0.564259,1.113047,0.849114,-0.114471,-0.295879,-0.485114,-0.145850,-0.415601,1.048809,0.616662,-0.076821,-0.253384,0.767209,0.455565,-0.095504,-0.574366,-0.063631,1.226524,1.004067,-0.648592,-0.661089,-0.870968,-0.665141,-0.556811,1.084855,1.096224,-1.200886,...,1.108422,-0.529252,1.187338,-0.492637,0.479583,1.564389,1.084642,-1.368438,0.391909,-0.014615,1.275730,0.285374,1.783651,-0.062758,-0.136548,0.469895,-1.438281,-0.003025,0.069650,0.042711,-0.897889,0.120552,1.375450,1.210344,1.669116,-1.800851,0.016816,0.107371,-0.222188,0.125138,0.158322,0.534504,0.568531,0.860508,0.672625,-0.023312,-0.514286,0.544672,0.651606,2
13433,-0.149132,-0.258929,1.956033,0.457178,-0.143329,0.634107,1.652280,0.976143,0.308342,-1.163448,2.059692,0.002985,0.459437,-0.286671,-0.810266,0.632554,0.165783,0.047033,-0.034022,-0.042977,-0.246012,0.235623,1.899299,-0.515237,-2.015499,0.014297,-0.205125,-0.324690,0.902470,-0.092142,0.238834,0.615287,-2.422139,-0.237405,-0.650236,-1.189852,-1.026486,0.413475,1.324875,-1.283207,...,-0.095755,-0.797786,-0.071407,-1.322971,-0.555721,1.108873,0.325867,-0.430457,0.656033,-0.568165,-0.522156,0.513207,1.279995,-0.543969,0.710175,0.333698,-1.267847,-0.110818,-1.578852,0.473709,-0.498924,0.889598,1.036346,-0.367827,1.500100,-1.900548,1.016101,-0.137424,-0.326034,0.399090,1.209888,-0.055077,1.132390,0.725480,-0.188348,1.107403,-0.114463,0.493490,0.288932,0


In [97]:
#!g1.1
dict_label_categories = {0:'Communication', 1:'Price', 2:'Quality', 3:'Safety'}
embeddings_categories['target'] = embeddings_categories['target'].map(dict_label_categories)

In [98]:
#!g1.1
X_emb_categories = embeddings_categories.drop('target', axis=1)
y_emb_categories = embeddings_categories['target']

In [100]:
#!g1.1
cb_semantic = CatBoostClassifier(
                task_type='GPU',
                devices='0:1',
                loss_function='MultiClass',
                eval_metric='AUC',
                learning_rate=0.008,
                iterations=4200,
                depth=10,
                verbose=100,
                l2_leaf_reg=4)

cb_categories = CatBoostClassifier(
                    task_type='GPU',
                    devices='0:1',
                    loss_function='MultiClass',
                    eval_metric='AUC',
                    learning_rate=0.013,
                    iterations=2500,
                    depth=10,
                    verbose=100,
                    l2_leaf_reg=4)

In [103]:
#!g1.1
# Обучение для семантики
cb_semantic.fit(X_emb_semantic, y_emb_semantic)

  self._init_pool(data, label, cat_features, text_features, embedding_features, pairs, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, feature_names, thread_count)
AUC is not implemented on GPU. Will use CPU for metric computation, this could significantly affect learning time


0:	total: 90.5ms	remaining: 6m 19s
100:	total: 7.89s	remaining: 5m 20s
200:	total: 15.7s	remaining: 5m 13s
300:	total: 23.6s	remaining: 5m 6s
400:	total: 31.4s	remaining: 4m 57s
500:	total: 38.8s	remaining: 4m 46s
600:	total: 46s	remaining: 4m 35s
700:	total: 53.1s	remaining: 4m 24s
800:	total: 60s	remaining: 4m 14s
900:	total: 1m 6s	remaining: 4m 4s
1000:	total: 1m 13s	remaining: 3m 55s
1100:	total: 1m 20s	remaining: 3m 46s
1200:	total: 1m 27s	remaining: 3m 37s
1300:	total: 1m 33s	remaining: 3m 28s
1400:	total: 1m 40s	remaining: 3m 20s
1500:	total: 1m 46s	remaining: 3m 11s
1600:	total: 1m 53s	remaining: 3m 3s
1700:	total: 1m 59s	remaining: 2m 55s
1800:	total: 2m 6s	remaining: 2m 47s
1900:	total: 2m 12s	remaining: 2m 40s
2000:	total: 2m 18s	remaining: 2m 32s
2100:	total: 2m 25s	remaining: 2m 25s
2200:	total: 2m 31s	remaining: 2m 17s
2300:	total: 2m 38s	remaining: 2m 10s
2400:	total: 2m 44s	remaining: 2m 3s
2500:	total: 2m 50s	remaining: 1m 56s
2600:	total: 2m 57s	remaining: 1m 48s
2700

<catboost.core.CatBoostClassifier at 0x7f5d3625ef10>

In [104]:
#!g1.1
# Обучение для категорий
cb_categories.fit(X_emb_categories, y_emb_categories)

AUC is not implemented on GPU. Will use CPU for metric computation, this could significantly affect learning time


0:	total: 95.6ms	remaining: 3m 58s
100:	total: 8.46s	remaining: 3m 21s
200:	total: 16.8s	remaining: 3m 12s
300:	total: 25.2s	remaining: 3m 4s
400:	total: 33.6s	remaining: 2m 55s
500:	total: 41.9s	remaining: 2m 47s
600:	total: 50.3s	remaining: 2m 39s
700:	total: 58.6s	remaining: 2m 30s
800:	total: 1m 6s	remaining: 2m 22s
900:	total: 1m 15s	remaining: 2m 13s
1000:	total: 1m 23s	remaining: 2m 5s
1100:	total: 1m 31s	remaining: 1m 56s
1200:	total: 1m 39s	remaining: 1m 48s
1300:	total: 1m 48s	remaining: 1m 39s
1400:	total: 1m 56s	remaining: 1m 31s
1500:	total: 2m 4s	remaining: 1m 22s
1600:	total: 2m 12s	remaining: 1m 14s
1700:	total: 2m 20s	remaining: 1m 5s
1800:	total: 2m 28s	remaining: 57.5s
1900:	total: 2m 36s	remaining: 49.3s
2000:	total: 2m 44s	remaining: 41.1s
2100:	total: 2m 52s	remaining: 32.8s
2200:	total: 3m 1s	remaining: 24.6s
2300:	total: 3m 9s	remaining: 16.4s
2400:	total: 3m 17s	remaining: 8.14s
2499:	total: 3m 25s	remaining: 0us


<catboost.core.CatBoostClassifier at 0x7f5d3625ecd0>

In [119]:
#!g1.1
test_emb_sem = pd.read_csv('final_result_test_sent_true_final.csv',index_col=0)
test_emb_cat = pd.read_csv('test_emb_cat_final.csv', index_col=0)

In [120]:
#!g1.1
predict_proba_sem = cb_semantic.predict_proba(test_emb_sem)

  self._init_pool(data, label, cat_features, text_features, embedding_features, pairs, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, feature_names, thread_count)


In [121]:
#!g1.1
proba_df_sem = pd.DataFrame(predict_proba_sem, columns=cb_semantic.classes_)

In [122]:
#!g1.1
test_df_sem = pd.concat([test, proba_df_sem], axis=1)

In [123]:
#!g1.1
predict_proba_cat = cb_categories.predict_proba(test_emb_cat)

  self._init_pool(data, label, cat_features, text_features, embedding_features, pairs, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, feature_names, thread_count)


In [125]:
#!g1.1
proba_df_cat = pd.DataFrame(predict_proba_cat, columns=cb_categories.classes_)

In [126]:
#!g1.1
test_df_sem = pd.concat([test_df_sem, proba_df_cat], axis=1)

In [127]:
#!g1.1
test_df_sem

Unnamed: 0,0,+,-,?,Communication,Price,Quality,Safety
0,15.03.2022 обратился на горячую линию для закр...,0.068790,0.092495,0.838716,0.916682,0.002011,0.077898,0.003409
1,"Уже который год в ТКБ не решается ""глобальная ...",0.001719,0.997239,0.001041,0.371167,0.002452,0.622238,0.004143
2,Добрый день,0.085990,0.310281,0.603729,0.962433,0.000597,0.036070,0.000900
3,"Добрый день Сегодня, зайдя в свой личный кабин...",0.001786,0.986534,0.011680,0.154141,0.004371,0.822328,0.019160
4,"Обслуживаюсь в Тинькофф пару лет, возникла жес...",0.119730,0.071195,0.809075,0.127115,0.017574,0.849488,0.005824
...,...,...,...,...,...,...,...,...
944,Отвратительный сервис и отношение к клиентам! ...,0.004005,0.994805,0.001190,0.416202,0.001345,0.581595,0.000857
945,28.04.2022 обратилась в банк о возможности пер...,0.028020,0.029387,0.942594,0.464212,0.023674,0.494884,0.017230
946,В начале 2021 года была акция по выплате 8% ке...,0.032582,0.035163,0.932255,0.022876,0.862200,0.110377,0.004547
947,Бездействие банка и некомпетентность сотрудников,0.002687,0.996156,0.001157,0.240982,0.000723,0.757429,0.000867


In [136]:
#!g1.1
# Create a function to find the second highest value in a list
def second_highest(arr):
    arr = np.array(arr)
    return arr.argsort()[-2]

# Create a new column called "Second_category"
test_df_sem["Second_category"] = 0

# Loop through each row in the dataframe
for i, row in test_df_sem.iterrows():
    # Get the four probability values
    probs = [row["Communication"], row["Price"], row["Quality"], row["Safety"]]
    # Find the index of the second highest value
    second_index = second_highest(probs)
    # If the second highest value is greater than or equal to 0.3, assign a value of 1 to the "Second_category" column
    if probs[second_index] >= 0.3:
        test_df_sem.at[i, "Second_category"] = 1

In [138]:
#!g1.1
test_df_sem.to_csv('ikanam_solution.csv')

In [139]:
#!g1.1
test_df_sem

Unnamed: 0,0,+,-,?,Communication,Price,Quality,Safety,Second_category
0,15.03.2022 обратился на горячую линию для закр...,0.068790,0.092495,0.838716,0.916682,0.002011,0.077898,0.003409,0
1,"Уже который год в ТКБ не решается ""глобальная ...",0.001719,0.997239,0.001041,0.371167,0.002452,0.622238,0.004143,1
2,Добрый день,0.085990,0.310281,0.603729,0.962433,0.000597,0.036070,0.000900,0
3,"Добрый день Сегодня, зайдя в свой личный кабин...",0.001786,0.986534,0.011680,0.154141,0.004371,0.822328,0.019160,0
4,"Обслуживаюсь в Тинькофф пару лет, возникла жес...",0.119730,0.071195,0.809075,0.127115,0.017574,0.849488,0.005824,0
...,...,...,...,...,...,...,...,...,...
944,Отвратительный сервис и отношение к клиентам! ...,0.004005,0.994805,0.001190,0.416202,0.001345,0.581595,0.000857,1
945,28.04.2022 обратилась в банк о возможности пер...,0.028020,0.029387,0.942594,0.464212,0.023674,0.494884,0.017230,1
946,В начале 2021 года была акция по выплате 8% ке...,0.032582,0.035163,0.932255,0.022876,0.862200,0.110377,0.004547,0
947,Бездействие банка и некомпетентность сотрудников,0.002687,0.996156,0.001157,0.240982,0.000723,0.757429,0.000867,0


In [None]:
#!g1.1
