In [None]:
# This is Oliver's first attempt at a sort of baseline model to predict colexification in CLICS.
# Here, I was trying to predict p(a and b colexify) in a "randomly selected" language.
# More specifically, I tried to do this by predicting the "colex.freq" column, so that I could divide these predictions by the number of languages in CLICS.
# This is because if a and b colexify 300 times out of 5000 languages, 300/5000 is a good estimate of the probability that a and b colexify in a "random" language.
# To predict this, I created a model that receives pairs of embeddings of senses as input.
# I then viewed predicting p(a and b colexify) as a regression task, with the dependent variable being this probability as measured above.

In [None]:
# Environment bootstrap for Google Colab.
# NOTE(review): `drive` is imported but never used in the visible cells; the
# relative "gdrive/MyDrive/..." path read later suggests that Google Drive is
# mounted by fastbook.setup_book() -- confirm before removing this import.
from google.colab import drive
# Install fastbook only when /content exists (i.e. when running on Colab).
! [ -e /content ] && pip install -Uqq fastbook
# NOTE(review): unpinned install; consider pinning a version for reproducibility.
! pip install torch-lr-finder
import fastbook
fastbook.setup_book()
# Wildcard import: brings the fastai tabular API into the notebook namespace.
from fastai.tabular.all import *

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from fastbook import *
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import Ridge
import gensim.downloader as gs
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import random as r
from torch_lr_finder import LRFinder

In [None]:
# Load the pretrained vectors through gensim's downloader. `wv` is indexed by
# word (wv[word]) everywhere below; the model name indicates 300-dimensional vectors.
wv = gs.load('word2vec-google-news-300') # These are the word2vec embeddings we are using

In [None]:
# Number of languages in the CLICS database. Dividing a colexification count by
# this value turns it into an estimated probability for a "random" language.
NUM_LANGUAGE = 3156

# Location of the CLICS colexification table (relative to the working directory).
CLICS_CSV_PATH = "gdrive/MyDrive/clics-colexification-data.csv"

# One row per pair of senses, with its colexification count in "colex.freq".
df = pd.read_csv(CLICS_CSV_PATH)
df.head()

Unnamed: 0,colex.freq,Concepticon_Gloss.xo,Concepticon_Gloss.yo,vision,assoc,affec,tax,fully_covered
0,340,TREE,WOOD,1,1,1,1,1
1,326,LEG,FOOT,1,1,1,1,1
2,296,MOON,MONTH,0,1,1,1,0
3,291,GO,WALK,0,1,1,1,0
4,284,HAND,ARM,1,1,1,1,1


In [None]:
# British (or otherwise variant) spelling -> spelling used for the embedding lookup.
dumb_british_spellings = {
  "armour": "armor",
  "grey": "gray",
  "mould": "mold",
  "neighbour": "neighbor",
  "axe": "ax",
  "moustache": "mustache",
  "plough": "plow",
  "mandarine": "mandarin",
}

# Rare or fused words -> a more common word or phrase standing in for them.
obscure_words = {
  "shoulderblade": "shoulder blade",
  "spearthrower": "spear thrower",
  "ridgepole": "ridge pole",
  "pimpleface": "pimple face",
  "tumpline": "backpack",
  "cushma": "clothing",
  "curassow": "tropical bird",
  "banisterium": "plant",
  "paca": "rodent",
  "netbag": "net bag",
  "muntjacs": "barking deer",
}

# This converts a sense to its word2vec embeddings. More specifically, for each sense I remove punctuation and then add the vectors of each individual word. 
# I also replace some obscure words and British spellings not recognized by word2vec with similar phrases that word2vec can recognize
def _normalizeToken(token):
  """Map one lower-cased token to the word or phrase that should be looked up in word2vec.

  Whole-token matches against dumb_british_spellings / obscure_words are applied first.
  A token that word2vec already knows is then kept unchanged, so a valid word that merely
  contains a key (e.g. "alpaca" contains "paca", "taxes" contains "axe") is not corrupted.
  Only tokens unknown to word2vec fall back to substring replacement, which still
  normalizes inflected or compound forms such as "neighbours".
  """
  if token in dumb_british_spellings:
    token = dumb_british_spellings[token]
  if token in obscure_words:
    return obscure_words[token]
  if token in wv:
    return token
  for old, new in dumb_british_spellings.items():
    token = token.replace(old, new)
  for old, new in obscure_words.items():
    token = token.replace(old, new)
  return token

def stringToVec(s):
  """Convert a sense (a word or short phrase) to a single embedding vector.

  The sense is lower-cased, the characters "(),\\t\\n" are removed, hyphens become spaces,
  each token is normalized with _normalizeToken, and the word2vec vectors of all
  recognized words are added together.

  Args:
    s: the sense gloss as a string, e.g. "SHOULDERBLADE" or "GO UP".

  Returns:
    A numpy array: the sum of the vectors of every word of the sense found in wv.

  Raises:
    ValueError: if none of the words of the sense is present in wv. ValueError is a
      subclass of Exception, so callers catching Exception keep working.
  """
  s = s.lower()
  for ch in "(),\t\n":
    s = s.replace(ch, "")
  s = s.replace("-", " ")

  # A replacement may itself be a phrase ("shoulder blade"), so split again afterwards.
  words = []
  for token in s.split(" "):
    words.extend(_normalizeToken(token).split(" "))

  vec = None
  for word in words:
    try:
      word_vec = np.array(wv[word])
    except KeyError: # skip words of the sense that aren't in word2vec, e.g., "of"
      continue
    vec = word_vec if vec is None else vec + word_vec
  if vec is None:
    raise ValueError("Word cannot be converted to vector") # none of the words in our sense were in word2vec
  return vec
def toTensor(arr):
  """Return the contents of `arr` as a torch tensor of dtype float32."""
  float_tensor = torch.tensor(arr, dtype=torch.float32)
  return float_tensor

In [None]:
dep_var = "colex.freq" # Column name of dependent variable (the raw colexification count)
NEGATIVE_SAMPLE_SEED = 0 # Seed of the private RNG used to draw non-colexifying pairs, so re-runs give the same rows

# dic maps each column label to its list of values; it becomes a dataframe holding
# pairs of embeddings of senses rather than the senses themselves.
dic = {}
embed_length = len(wv["test"]) # Length of vectors in word2vec
for j in range(2*embed_length): # columns 0..embed_length-1 hold the first sense's embedding, embed_length..2*embed_length-1 the second's
  dic[j] = []
dic[dep_var] = []

vec_dic = {} # Cache: sense -> embedding. Not as simple as wv[s], because some senses are phrases
sensePairs = set() # Every pair of senses seen so far (both orders, stored as concatenated strings); used below to draw pairs that never colexify

def _senseVec(sense):
  """Return the embedding of `sense`, computing it with stringToVec on first use and caching it in vec_dic."""
  if sense not in vec_dic:
    vec_dic[sense] = stringToVec(sense)
  return vec_dic[sense]

def _appendPair(xvec, yvec, freq):
  """Append one row to dic: the embedding of the first sense, then of the second, then the target value."""
  for j, value in enumerate(np.concatenate((xvec, yvec))):
    dic[j].append(value)
  dic[dep_var].append(freq)

num_skipped = 0 # Number of rows dropped because one of the two senses could not be embedded
for x, y, freq in zip(df["Concepticon_Gloss.xo"], df["Concepticon_Gloss.yo"], df[dep_var]):
  sensePairs.add(x+y)
  sensePairs.add(y+x)
  try:     # Only the embedding lookups are guarded: senses that cannot be converted are passed over
    xvec = _senseVec(x)
    yvec = _senseVec(y) # each sense gets its own embedding, including when it is already cached
  except Exception:
    num_skipped += 1
    continue
  # The raw count is stored; dividing by NUM_LANGUAGE to get a probability is left to later steps.
  _appendPair(xvec, yvec, freq)

# To get more data, we now add pairs of senses that never colexify in our data set, giving frequency 0
# The number of pairs we add is 10% of the size of our data set

num_zeros = int(len(df)/10)
senses = list(vec_dic.keys())
sense_indices = {senses[i]:i for i in range(len(senses))}
if num_zeros > 0 and len(senses) < 2:
  raise ValueError("Need at least two senses with embeddings to sample non-colexifying pairs")
neg_rng = r.Random(NEGATIVE_SAMPLE_SEED) # private generator: leaves the global random state untouched
for i in range(num_zeros):
  x = neg_rng.randint(0, len(senses)-1) # Here we take two random senses that never colexify
  y = neg_rng.randint(0, len(senses)-1)
  # Reject a sense paired with itself as well as any pair already present in the data
  while x == y or senses[x]+senses[y] in sensePairs:
    x = neg_rng.randint(0, len(senses)-1)
    y = neg_rng.randint(0, len(senses)-1)
  sensePairs.add(senses[x]+senses[y])
  # Every entry of senses is a key of vec_dic, so the cached embeddings can be reused
  _appendPair(vec_dic[senses[x]], vec_dic[senses[y]], 0)

df_w2v = pd.DataFrame.from_dict(dic) # This is our new dataframe. We have replaced Concepticon_Gloss.xo, Concepticon_Gloss.yo with their embeddings
df_w2v.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,512,513,514,515,516,517,518,519,520,521,522,523,524
,525,526,527,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,colex.freq
0,0.484375,0.122559,-0.157227,0.034668,-0.219727,-0.235352,0.113281,0.02771,0.132812,0.287109,0.105469,-0.241211,0.019897,0.033203,-0.069336,-0.082031,-0.259766,-0.1875,-0.006439,0.090332,0.007599,-0.07666,-0.10498,-0.125,0.189453,-0.121582,-0.18457,0.047852,0.220703,-0.257812,-0.047607,-0.219727,-0.030273,-0.134766,-0.04541,-0.28125,-0.066406,-0.373047,0.0271,0.022461,0.150391,-0.146484,0.146484,-0.208008,0.128906,-0.240234,-0.294922,0.07959,0.025513,0.071777,-0.044189,0.115723,0.091797,-0.037598,0.279297,-0.063477,0.022827,0.147461,-0.02832,-0.077148,0.082031,-0.134766,0.209961,0.085449,-0.164062,-0.116699,0.166992,-0.09375,-0.111328,-0.002975,0.135742,-0.206055,-0.078125,-0.433594,0.049316,0.118652,0.210938,-0.351562,0.100586,-0.002914,0.097656,0.478516,-0.014099,-0.15918,-0.121582,0.035889,-0.204102,-0.027954,-0.044922,0.198242,-0.119629,0.133789,0.00473,-0.225586,-0.322266,-0.046875,0.036377,-0.310547,0.048828,-0.322266,-0.182617,-0.175781,0.248047,0.269531,0.214844,0.026123,0.100586,-0.109375,-0.388672,-0.3125,-0.597656,-0.322266,0.068359,-0.067871,0.125977,-0.222656,0.164062,0.024414,-0.251953,0.084961,0.046143,0.104492,-0.087402,0.172852,0.135742,0.229492,0.059082,-0.090332,0.042236,0.151367,-0.117676,-0.335938,0.017212,-0.369141,0.080566,0.384766,-0.210938,-0.135742,0.242188,0.082031,0.013794,0.03125,-0.11377,0.294922,0.157227,0.205078,0.028198,-0.227539,-0.171875,-0.036133,0.435547,0.044678,-0.24707,0.3125,-0.026367,-0.199219,0.008789,-0.085449,-0.036621,-0.112305,-0.179688,0.106445,0.170898,0.182617,-0.049805,-0.287109,0.12793,-0.087891,-0.34375,0.047119,-0.168945,0.223633,-0.241211,-0.083984,-0.324219,-0.137695,0.147461,0.200195,-0.100586,0.277344,-0.310547,-0.130859,-0.052246,0.11377,-0.040283,0.140625,-0.054688,0.012634,-0.037109,0.047119,0.091797,-0.140625,-0.429688,-0.011169,0.216797,-0.021973,0.042969,-0.21582,-0.084473,-0.503906,0.004852,-0.155273,-0.253906,0.086426,0.028809,0.099609,0.285156,0.241211,0.064941,-0.108398,-0.527344,-0.116699,-0.27539
1,-0.108398,0.047119,0.143555,-0.136719,-0.09668,0.353516,0.445312,0.010498,-0.208008,0.020874,0.023071,-0.126953,-0.066895,0.05957,0.067383,0.062012,0.039062,-0.134766,-0.181641,-0.135742,-0.034668,-0.152344,0.04248,0.152344,-0.161133,0.029053,-0.150391,-0.073242,-0.214844,-0.104492,-0.392578,-0.211914,0.037354,0.048096,0.043213,-0.026489,-0.208008,-0.457031,-0.047119,-0.043213,-0.149414,0.140625,0.006012,-0.527344,0.012085,-0.230469,-0.02124,-0.112793,-0.103516,-0.255859,-0.267578,0.12793,0.388672,-0.246094,0.095703,-0.394531,0.029907,0.019287,-0.014709,-0.038818,-0.008789,-0.004456,0.024048,0.07959,-0.110352,-0.102539,0.214844,0.081055,-0.1875,0.251953,0.087891,-0.069824,-0.162109,-0.134766,0.114258,-0.023438,0.019775,-0.083496,0.206055,-0.1875,0.172852,0.029419,0.035889,-0.118652,-0.088867,-0.042969,0.019165,0.029053,0.237305,-0.089355,-0.075684,0.099121,0.048096,0.133789,-0.328125,0.296875,0.326172,0.014343,-0.316406,-0.213867,-0.320312,0.134766,-0.294922,-0.133789,0.101562,-0.054199,0.132812,-0.059082,-0.105957,0.040283,0.017334,-0.100586,-0.071289,-0.231445,0.163086,0.11377,-0.400391,-0.003998,0.128906,0.001938,-0.141602,0.02063,-0.12793,-0.198242,0.034668,0.003525,0.122559,0.125,-0.110352,0.298828,0.049316,-0.056396,-0.154297,-0.12207,0.273438,-0.099121,-0.055176,0.098633,0.242188,-0.224609,-0.098633,-0.208984,-0.196289,-0.034912,0.00082,-0.155273,-0.222656,0.210938,0.067383,-0.144531,-0.177734,-0.194336,0.026489,-0.033447,-0.053467,-0.241211,0.273438,0.223633,-0.330078,0.279297,-0.292969,-0.145508,0.181641,-0.045166,-0.099609,0.022339,-0.143555,-0.198242,0.03064,-0.073242,0.125,0.168945,0.12207,-0.023193,0.214844,-0.074707,0.011353,-0.177734,0.243164,0.038574,0.21582,-0.146484,-0.267578,-0.067871,-0.192383,0.196289,-0.421875,-0.020508,-0.062988,-0.013245,0.017334,-0.114258,-0.086426,0.394531,-0.015869,-0.216797,-0.055176,-0.267578,-0.176758,-0.132812,0.202148,-0.134766,0.05542,0.168945,0.031494,-0.214844,0.108887,0.064941,0.046143,-0.195312,-0.043701,0.0427
25,0.013123,-0.125,0.071777,-0.232422,0.253906,-0.12793,-0.172852,-0.146484,-0.070801,0.214844,0.253906,-0.339844,-0.162109,0.238281,0.163086,0.316406,0.094238,-0.11377,0.02002,-0.037109,0.249023,-0.03418,-0.060791,-0.216797,0.213867,0.053223,-0.089355,-0.092773,0.382812,-0.048828,0.015076,-0.013123,-0.237305,-0.211914,-0.078613,-0.087402,-0.05542,-0.083008,-0.061279,-0.013977,0.15332,0.068359,-0.169922,-0.119629,-0.043945,-0.11084,-0.033936,-0.104492,-0.194336,-0.138672,0.086426,0.275391,-0.126953,-0.150391,0.073242,-0.068359,-0.227539,0.220703,0.326172,0.275391,-0.018311,0.020508,-0.112793,0.498047,0.060059,-0.065918,-0.227539,0.028076,0.137695,0.196289,0.039551,0.225586,-0.182617,-0.081055,-0.09375,0.085938,0.07959,-0.154297,-0.099121,-0.049561,0.137695,0.357422,-0.296875,-0.073242,-0.004608,-0.441406,-0.165039,-0.300781,0.294922,0.466797,0.171875,-0.394531,0.099121,0.017212,0.033447,-0.195312,-0.103027,-0.333984,0.269531,0.318359,0.010681,-0.030762,0.229492,0.119629,0.015625,-0.117676,0.07666,-0.025391,0.198242,-0.296875,-0.161133,-0.168945,-0.015625,0.125,-0.112305,0.166016,0.009949,0.220703,-0.163086,-0.085938,-0.400391,0.056641,0.128906,0.027344,-0.183594,0.118652,0.179688,0.192383,-0.28125,0.164062,-0.05249,-0.257812,-0.158203,0.083496,0.037109,-0.167969,-0.060791,-0.349609,-0.238281,0.103027,0.071289,-0.03833,-0.084473,0.135742,-0.020386,0.275391,-0.058838,0.220703,0.051025,-0.177734,0.011658,-0.005035,-0.191406,-0.052002,0.300781,-0.042236,-0.186523,0.073242,-0.027832,-0.035156,-0.384766,-0.345703,-0.015198,-0.158203,0.488281,-0.291016,0.06543,-0.024658,0.112793,0.170898,0.104492,0.126953,-0.009216,0.092773,0.146484,340
1,0.000908,0.120117,0.021973,-0.279297,-0.118652,-0.089844,0.071777,-0.224609,-0.139648,0.05957,-0.044678,-0.043701,0.131836,-0.006866,0.026733,0.122559,-0.251953,-0.007721,-0.084473,0.102539,-0.024292,0.095215,0.209961,0.192383,0.044922,-0.402344,-0.18457,0.183594,0.244141,-0.108398,0.035645,0.19043,0.172852,-0.099121,-0.138672,0.235352,-0.031494,0.049561,0.050049,0.140625,-0.155273,0.020264,0.351562,0.133789,0.076172,0.165039,0.075195,0.009888,0.200195,0.22168,0.031494,0.259766,-0.076172,-0.140625,0.064941,-0.031494,0.201172,0.223633,0.141602,-0.120117,-0.15918,0.062988,0.144531,-0.223633,-0.292969,0.029419,0.206055,0.011169,0.084473,0.132812,0.063965,0.089355,-0.070801,-0.158203,-0.365234,-0.072754,0.03064,-0.178711,0.149414,-0.167969,0.111816,0.058838,0.040771,-0.11377,0.007202,-0.00824,-0.094727,0.020142,0.135742,-0.0065,-0.152344,0.028687,0.123535,0.056152,-0.101562,-0.273438,-0.147461,-0.271484,-0.023804,-0.049805,-0.094238,-0.378906,-0.060791,-0.161133,0.081543,-0.106934,0.060059,-0.118164,0.139648,-0.233398,-0.027344,-0.263672,0.462891,-0.009338,-0.271484,-0.119141,0.154297,0.07666,-0.021606,0.106934,-0.073242,0.021606,0.099609,0.261719,0.008789,0.042969,0.033691,-0.102539,-0.092285,0.124023,-0.21875,0.138672,-0.21582,-0.367188,-0.134766,-0.208008,-0.160156,0.222656,-0.219727,0.414062,0.194336,0.033691,0.124512,0.333984,-0.119141,0.097656,-0.176758,-0.140625,0.026733,0.089844,0.087402,-0.114746,0.071289,0.294922,0.371094,-0.101074,-0.216797,-0.263672,0.031494,-0.324219,0.023804,0.138672,0.353516,-0.283203,0.00769,-0.060303,0.166016,-0.179688,-0.117188,-0.196289,-0.139648,-0.154297,0.018921,-0.347656,-0.4375,-0.234375,-0.015564,-0.03064,-0.115234,0.07959,-0.114258,-0.04834,0.124023,0.091309,-0.004211,0.095215,0.130859,0.07959,0.273438,-0.016724,0.087402,-0.087402,0.220703,0.047607,0.01178,0.199219,0.09668,0.061035,-0.074219,-0.025513,-0.07666,-0.112305,-0.091797,0.294922,0.279297,-0.435547,-0.160156,0.059814,0.134766,-0.104492,-0.259766,0.28125,-0.324219,0.1
97266,0.002274,-0.298828,-0.010132,-0.119141,-0.070801,-0.198242,0.037109,-0.124023,0.101562,-0.061523,-0.067871,-0.169922,0.073242,-0.128906,0.083496,0.019653,0.203125,0.145508,-0.080566,-0.049805,0.000479,0.188477,-0.015381,0.203125,0.07959,0.208984,0.143555,-0.086914,0.089844,-0.180664,0.25,0.040039,-0.186523,-0.036865,0.259766,-0.143555,0.0177,-0.075195,0.019287,-0.046387,0.111816,-0.019531,-0.087402,0.332031,-0.113281,-0.176758,-0.062988,-0.067871,0.054199,0.060059,0.124512,-0.083008,0.209961,-0.120605,-0.041016,0.079102,-0.138672,0.167969,0.124512,-0.089355,0.189453,-0.273438,0.119141,-0.006653,-0.116699,-0.043213,0.142578,-0.149414,0.051025,-0.306641,-0.097168,0.004852,-0.208984,0.171875,0.050537,-0.063965,0.094238,-0.019897,-0.026611,-0.162109,0.206055,0.114258,0.162109,-0.008179,0.008606,0.047607,0.031982,0.279297,0.033203,-0.24707,0.014404,-0.058594,0.205078,-0.097168,0.017944,0.116699,0.265625,0.01355,0.176758,-0.039062,0.294922,-0.114258,-0.092773,0.072266,0.002533,0.084473,-0.117676,0.242188,-0.227539,0.085938,-0.210938,-0.207031,-0.314453,0.080078,0.277344,-0.28125,0.098145,0.001625,0.189453,-0.203125,-0.145508,-0.09375,-0.031738,0.169922,-0.433594,-0.071777,0.208984,-0.073242,0.162109,-0.193359,0.02832,0.125,-0.121582,0.177734,0.074219,-0.400391,0.021606,0.233398,-0.074707,-0.347656,0.170898,0.059082,0.052002,-0.075195,-0.019531,-0.137695,0.066406,-0.013062,-0.12793,-0.122559,0.049072,-0.05127,-0.000683,0.225586,-0.003998,0.265625,0.102539,0.02417,-0.070312,-0.003891,-0.111816,-0.098145,0.058838,-0.123047,-0.10791,-0.137695,0.009827,-0.022217,0.116211,-0.128906,-0.003876,-0.004089,-0.067383,0.022583,0.363281,-0.214844,-0.166992,0.020142,0.160156,0.22168,0.161133,-0.068848,0.033447,-0.207031,0.074707,0.022949,-0.125977,-0.057129,0.186523,-0.170898,0.041992,0.113281,0.061523,-0.095703,-0.146484,-0.101562,-0.145508,0.012939,0.112305,0.21875,-0.09375,0.08252,0.196289,0.00766,-0.026367,0.015015,-0.087402,-0.051758,-0.07666,0.322266,0.173828,-0.006134,0.222
656,-0.195312,-0.038574,0.150391,-0.3125,0.292969,0.013855,-0.041504,-0.269531,0.141602,-0.109863,0.033447,-0.047119,0.139648,0.183594,0.012512,-0.012146,0.123535,-0.070312,0.032715,0.044922,-0.15332,-0.226562,0.171875,0.133789,-0.013367,-0.376953,0.246094,0.183594,0.160156,-0.039551,-0.243164,0.018311,0.055664,0.15625,0.014771,0.22168,-0.400391,0.125977,-0.026245,0.150391,-0.275391,-0.382812,-0.074707,-0.078613,-0.022095,-0.198242,0.022461,-0.050293,-0.081055,0.137695,-0.033203,-0.419922,0.033447,-0.079102,-0.105957,0.212891,0.07959,-0.010681,-0.142578,0.11377,0.139648,0.166992,0.033447,0.111328,-0.150391,0.116699,0.025024,-0.072266,0.271484,-0.037842,0.212891,-0.010193,-0.128906,0.050537,-0.103027,-0.055908,0.259766,0.279297,-0.223633,0.169922,-0.131836,0.163086,0.138672,-0.251953,-0.166992,-0.021362,0.013245,-0.333984,-0.053955,-0.056152,-0.057129,0.124512,0.040527,-0.025513,0.030151,-0.021973,-0.296875,0.242188,-0.052734,0.002258,0.242188,0.15332,-0.257812,0.005341,-0.137695,-0.144531,0.05835,-0.129883,0.161133,0.013306,-0.035889,0.00135,0.053467,0.011169,0.213867,-0.029175,-0.176758,-0.014099,0.080566,-0.214844,-0.155273,0.231445,-0.217773,0.242188,-0.216797,0.064453,0.207031,0.263672,-0.249023,0.051514,0.310547,-0.231445,-0.057617,-0.062988,0.049805,-0.140625,-0.04541,-0.089844,0.155273,0.102051,-0.031982,-0.055908,0.129883,-0.022217,0.152344,-0.000935,-0.241211,-0.024048,0.067871,-0.147461,-0.077637,0.09668,-0.091309,0.106445,0.014526,0.112305,0.06543,-0.202148,0.026489,-0.355469,0.077637,0.139648,-0.017944,-0.0065,0.133789,-0.201172,0.032715,0.168945,0.084473,-0.125,0.025757,0.062988,-0.132812,326
2,-0.038574,0.189453,0.206055,0.171875,0.054199,-0.224609,0.414062,-0.353516,0.214844,0.056885,0.188477,0.15918,-0.062256,-0.050293,-0.006226,0.255859,-0.094238,0.004913,0.291016,-0.060303,-0.394531,-0.040771,0.163086,-0.054443,0.025146,0.078613,-0.013611,0.027466,0.431641,0.137695,-0.419922,-0.175781,-0.179688,0.041748,-0.086914,-0.112793,0.010132,-0.013977,0.177734,0.012756,0.103516,0.047119,0.410156,0.285156,0.10791,-0.170898,-0.055908,-0.308594,0.043213,-0.030762,-0.012024,0.243164,-0.102539,0.150391,-0.124512,-0.347656,0.435547,-0.053467,-0.045898,-0.042236,-0.032471,-0.231445,0.11084,0.003479,-0.217773,0.072266,-0.116699,0.045166,0.131836,0.091309,0.040527,-0.296875,0.046631,0.141602,-0.384766,0.053223,0.110352,0.166016,-0.138672,0.234375,0.139648,-0.111816,0.273438,-0.439453,-0.033691,-0.157227,0.155273,-0.392578,0.355469,0.050537,0.058594,0.004242,-0.220703,-0.294922,0.044189,-0.121582,0.373047,-0.039307,0.378906,-0.031982,-0.103516,-0.067871,0.196289,-0.002762,0.056641,0.239258,-0.05249,-0.001328,-0.053223,-0.102539,-0.178711,0.104004,0.129883,0.15625,0.044678,0.056885,0.010376,-0.025146,-0.220703,0.027222,-0.137695,-0.237305,-0.013611,-0.011658,-0.142578,-0.289062,0.00946,0.143555,0.300781,0.05127,-0.063477,-0.01178,-0.209961,0.089844,-0.015015,0.251953,-0.047119,-0.046631,0.25,0.421875,-0.078125,-0.151367,-0.242188,0.300781,0.158203,0.083984,-0.053711,-0.125977,0.11377,-0.07666,0.245117,-0.051025,-0.124023,0.013123,0.006927,-0.416016,0.165039,0.054688,0.100586,-0.029175,-0.367188,0.072266,0.255859,-0.145508,-0.000542,-0.285156,0.251953,0.136719,-0.197266,0.104004,-0.044678,0.330078,0.096191,0.349609,-0.006104,-0.222656,-0.130859,-0.133789,-0.128906,-0.242188,-0.355469,0.01709,-0.194336,-0.200195,-0.068359,-0.171875,-0.155273,0.100098,0.064453,-0.108887,0.07959,0.044434,-0.055664,-0.106934,-0.363281,0.296875,0.40625,-0.022217,0.079102,-0.008789,0.032471,-0.133789,-0.253906,-0.269531,0.194336,-0.069824,0.015259,-0.130859,-0.314453,0.162109,-0.248047,0.1875,
0.09082,-0.149414,0.182617,0.04541,0.085938,0.118164,0.018066,-0.210938,0.347656,0.07666,0.449219,-0.050049,0.069336,0.12793,0.012329,0.086426,0.339844,0.082031,0.378906,-0.192383,0.12793,0.070312,0.208008,-0.160156,-0.279297,-0.261719,-0.057617,-0.279297,0.144531,-0.173828,0.043701,-0.150391,-0.003235,-0.287109,0.079102,-0.037109,0.032227,0.014709,-0.189453,-0.271484,0.259766,0.287109,-0.125977,-0.308594,0.150391,-0.265625,-0.050293,-0.236328,-0.06543,-0.425781,-0.19043,-0.152344,0.238281,0.122559,-0.257812,-0.238281,0.172852,0.064941,-0.084473,-0.061035,0.347656,0.096191,0.075195,0.205078,-0.236328,-0.080566,-0.145508,-0.115234,0.008057,0.308594,0.053711,0.176758,0.332031,-0.138672,0.099609,-0.00824,-0.004364,0.060059,0.056641,0.170898,-0.257812,-0.039795,-0.101074,0.005981,-0.025391,-0.335938,0.164062,-0.253906,0.038818,0.15918,-0.088867,0.06543,-0.098633,-0.04541,0.21875,-0.243164,0.139648,0.154297,0.05127,0.103027,0.10498,-0.177734,-0.023804,0.158203,-0.029053,-0.080566,-0.138672,-0.062988,0.027954,-0.050537,-0.166016,-0.181641,-0.063965,0.016479,0.275391,0.188477,0.057373,0.090332,0.112793,-0.119629,0.006195,0.021484,0.043945,-0.022095,-0.24707,-0.205078,-0.310547,0.006775,0.044434,-0.117188,0.223633,0.131836,-0.216797,-0.050537,0.095703,0.1875,0.066895,0.057861,0.143555,-0.226562,-0.125977,-0.048828,0.011169,0.157227,-0.07666,-0.006989,0.077148,-0.132812,-0.120117,0.157227,-0.013062,0.111816,0.043945,0.101562,0.03833,-0.018799,0.200195,0.116211,0.042725,-0.070312,-0.062012,0.19043,0.071289,0.006195,0.049072,0.091309,0.212891,0.1875,0.034424,-0.195312,-0.091309,0.134766,0.067383,-0.078125,-0.170898,-0.105469,0.024292,0.056641,0.222656,0.181641,-0.207031,-0.408203,-0.320312,-0.225586,0.05542,0.102539,0.110352,-0.096191,0.160156,-0.104492,-0.200195,-0.009766,-0.116699,-0.253906,-0.095703,0.100586,-0.047363,-0.083984,0.033936,0.098633,0.172852,-0.062256,0.098145,-0.072754,0.1875,0.067871,0.130859,0.066895,0.057861,-0.198242,-0.135742,-0.125,0.223633,0.056396,0.22
9492,-0.050537,0.064453,-0.058594,-0.137695,-0.104004,0.07373,-0.125,-0.176758,-0.161133,0.025391,-0.014465,-0.129883,0.099121,0.128906,-0.057617,-0.103516,0.047363,-0.077637,0.043457,-0.031494,-0.04126,-0.213867,-0.150391,0.030029,0.123535,-0.036133,-0.09082,0.201172,-0.146484,-0.08252,-0.123047,0.051758,0.071289,0.099609,0.063965,0.121094,-0.011047,-0.062256,0.019531,0.174805,-0.084473,0.072754,-0.177734,0.077637,0.020996,-0.125,-0.233398,0.054932,-0.171875,0.046631,0.095215,0.044434,0.152344,-0.251953,0.035156,-0.10498,0.030884,0.214844,0.09375,0.071289,0.076172,0.105469,-0.173828,0.158203,-0.027588,0.030151,-0.035156,-0.216797,0.090332,0.144531,0.03833,-0.116211,0.005341,-0.085938,0.054199,0.01239,0.057861,0.208984,-0.135742,-0.096191,0.133789,0.053467,-0.114258,0.091797,0.010315,-0.145508,0.000187,-0.202148,-0.022949,0.088379,-0.125977,-0.164062,0.208984,0.119141,0.133789,0.124023,-0.069824,-0.102051,-0.277344,0.004547,0.080566,-0.084473,-0.125977,-0.072266,-0.025146,-0.029907,0.198242,-0.038818,0.039795,0.069336,0.059326,0.034912,-0.012268,-0.040771,0.125,0.102539,-0.199219,-0.099121,-0.125977,0.078125,0.115723,0.079102,0.054688,-0.07666,-0.053955,-0.137695,0.060791,0.098145,0.251953,0.115234,0.040771,0.174805,0.040527,0.07666,-0.094727,0.031006,-0.004913,0.19043,-0.053467,-0.10791,0.145508,0.097656,0.128906,0.168945,0.036621,0.147461,0.037598,0.135742,-0.064941,-0.123047,-0.062256,0.018433,-0.196289,-0.04126,0.006348,0.031982,0.138672,0.177734,-0.143555,-0.11377,0.112305,0.170898,0.061768,-0.015137,-0.093262,-0.166016,-0.055664,0.031738,-0.152344,-0.126953,0.152344,-0.01001,-0.149414,296
3,-0.026367,0.068359,-0.031128,0.219727,0.003418,-0.009033,0.10791,-0.174805,0.077148,0.000383,-0.102539,-0.017334,-0.030884,0.057617,-0.109863,0.061035,0.248047,0.005463,0.034912,0.00766,-0.10791,0.216797,0.126953,0.146484,0.155273,0.044678,0.075195,-0.145508,-0.077148,-0.085449,-0.011597,0.07959,-0.194336,-0.257812,-0.098633,-0.138672,-0.000414,-0.090332,0.07666,0.133789,0.051758,-0.048096,0.188477,-0.034424,-0.07959,0.0354,-0.102539,-0.133789,-0.106445,0.074219,-0.024658,0.199219,0.154297,-0.122559,0.047119,-0.300781,-0.07373,-0.064453,0.09082,0.066895,-0.036865,-0.148438,-0.012451,0.071289,-0.063477,-0.205078,-0.239258,0.091797,-0.035645,0.06543,0.074707,0.134766,0.141602,-0.04541,-0.191406,-0.261719,0.020142,0.117188,0.035156,0.080078,-0.001984,-0.091797,0.12793,0.029541,0.09668,0.088867,0.056641,0.06543,-0.029297,0.021973,0.122559,0.310547,0.013611,-0.168945,-0.039795,-0.09668,0.196289,0.263672,-0.021362,0.00267,-0.145508,0.074219,0.021484,0.058594,-0.050293,-0.198242,0.052979,-0.174805,0.102051,-0.01239,0.12793,-0.102539,-0.125,-0.179688,-0.031006,0.040039,0.078613,-0.193359,0.142578,0.033936,0.009827,0.046875,-0.002914,0.275391,-0.007507,0.163086,-0.048828,-0.074707,-0.12793,-0.052002,-0.15332,-0.142578,0.140625,0.186523,0.025269,-0.10791,0.09375,0.210938,0.116699,0.102051,0.116699,-0.079102,0.043457,0.117188,0.124512,0.098145,-0.24707,-0.235352,-0.103027,-0.112305,-0.069824,0.126953,-0.024292,0.055176,-0.109375,-0.103027,0.095215,-0.195312,-0.152344,0.148438,0.088379,0.04834,0.013733,-0.061279,0.050049,-0.310547,0.093262,0.012146,0.095703,-0.162109,-0.233398,-0.261719,-0.05127,-0.074219,-0.049072,0.106445,0.113281,-0.029785,0.054688,0.160156,-0.212891,-0.142578,-0.016602,0.154297,0.01416,-0.099121,0.084473,0.087402,0.200195,-0.050781,0.207031,0.12793,0.175781,0.178711,-0.015869,0.004059,-0.005402,-0.008484,-0.001236,-0.142578,-0.098145,0.039062,0.012268,0.000633,0.052734,0.09375,-0.032227,-0.289062,0.12793,-0.005981,0.076172,0.071777,0.083496,0.020752,-0.19
6289,0.067871,0.208008,-0.010864,-0.044434,-0.277344,-0.117676,-0.133789,0.081543,0.012512,0.181641,-0.075195,0.112793,0.075684,0.050537,-0.058594,0.052979,-0.043701,-0.058838,0.154297,0.00946,0.116211,-0.151367,0.011108,0.141602,-0.05542,0.183594,0.117188,-0.12793,-0.070801,0.003723,-0.041992,-0.091309,0.109375,0.12793,0.022583,0.193359,0.054688,0.022583,-0.032227,0.125,-0.050049,-0.033447,-0.006165,-0.032715,-0.03064,-0.131836,-0.067383,-0.10498,0.126953,0.179688,0.075684,0.108398,0.089355,-0.017212,0.09082,0.006653,0.230469,0.033203,0.169922,0.149414,-0.208008,-0.059814,-0.165039,-0.091797,-0.000147,0.082031,-0.125,0.009888,0.107422,-0.010437,-0.019653,-0.023071,0.088867,-0.014893,0.21875,0.04126,0.298828,-0.232422,0.165039,-0.022339,-0.064453,0.15918,-0.170898,0.196289,-0.099609,0.000248,-0.008728,-0.125,0.457031,-0.060059,-0.05127,-0.048584,-0.208984,-0.154297,0.140625,0.253906,-0.098145,0.02417,0.129883,0.080566,0.102051,0.287109,0.056885,0.073242,-0.114746,-0.043945,0.248047,-0.116211,0.086426,0.080078,-0.199219,-0.061523,-0.045898,0.006317,-0.34375,0.014587,0.150391,-0.138672,-0.298828,-0.255859,-0.423828,0.106445,-0.285156,0.016602,0.105469,-0.197266,-0.144531,0.057373,-0.080566,-0.087402,0.133789,-0.197266,-0.168945,-0.197266,0.020508,-0.043457,0.147461,0.112793,0.066895,0.265625,-0.263672,0.020386,-0.15625,0.056396,0.036621,0.104492,0.031738,0.129883,0.03833,0.162109,-0.05957,-0.230469,0.198242,-0.061523,0.182617,-0.020508,0.126953,0.052246,-0.25,-0.203125,-0.075684,-0.181641,-0.063965,0.062988,-0.039307,-0.283203,-0.217773,0.098145,0.008301,0.169922,0.126953,-0.220703,0.135742,-0.034912,-0.116211,-0.0271,0.049316,-0.188477,-0.025024,-0.118652,-0.335938,0.125,0.212891,-0.043701,-0.01709,-0.101562,-0.071289,-0.128906,0.120117,0.121582,0.098633,0.161133,-0.306641,0.039062,-0.055176,-0.025146,-0.108398,0.055176,-0.169922,0.15918,0.022217,-0.150391,-0.253906,-0.022583,0.304688,-0.245117,0.21582,-0.082031,0.291016,-0.019531,0.091797,0.121582,0.25,0.086914,0.05
249,-0.123535,-0.039551,0.079102,0.141602,-0.040283,-0.011963,-0.067383,0.035645,-0.141602,0.15332,0.063477,0.079102,-0.063965,-0.061768,-0.013184,0.176758,0.08252,-0.287109,0.145508,-0.267578,-0.148438,-0.072266,0.10791,0.130859,0.150391,0.161133,-0.051514,-0.216797,-0.283203,0.196289,0.181641,-0.003647,0.226562,-0.292969,-0.131836,-0.237305,0.277344,-0.06543,-0.093262,-0.139648,-0.318359,-0.036133,0.113281,-0.049072,-0.095703,0.080078,0.03418,0.097656,-0.068848,0.047363,-0.18457,-0.021118,0.074219,0.041992,-0.066895,-0.067871,0.326172,0.132812,0.042969,-0.081055,-0.003876,-0.03418,0.091309,0.177734,-0.076172,-0.029297,-0.145508,-0.039795,-0.122559,0.255859,-0.269531,-0.123535,-0.126953,-0.236328,0.175781,0.006378,-0.028198,-0.002502,0.140625,-0.072754,-0.086426,-0.077637,-0.023438,-0.100098,-0.037842,0.141602,0.179688,0.028564,-0.033936,-0.296875,0.209961,0.032471,0.237305,-0.109375,0.226562,0.057861,0.194336,-0.022705,0.191406,-0.237305,0.316406,0.002136,-0.064941,0.170898,0.259766,0.018921,-0.145508,0.130859,0.289062,-0.029419,0.05127,-0.062012,-0.029297,-0.224609,0.132812,-0.022095,0.142578,0.029541,-0.049072,-0.177734,0.066406,-0.034668,-0.051025,-0.100098,0.109863,0.009888,0.136719,0.064453,-0.216797,0.101074,-0.129883,-0.271484,-0.173828,0.212891,0.251953,0.033203,0.347656,0.185547,-0.208984,0.189453,0.034912,0.209961,0.185547,0.147461,0.085449,-0.1875,-0.229492,0.04248,-0.001999,0.05127,0.22168,-0.02063,0.183594,0.114258,0.112793,-0.148438,-0.205078,0.04541,-0.283203,0.111328,-0.225586,0.310547,-0.175781,0.141602,0.072266,0.059814,-0.037354,-0.279297,-0.07959,-0.073242,291
4,0.093262,-0.046387,-0.133789,0.048096,-0.316406,0.10791,0.300781,-0.28125,-0.018311,0.048584,-0.044434,-0.182617,-0.117188,0.137695,-0.178711,0.003342,-0.107422,0.013428,0.089844,0.072266,0.216797,0.103516,0.189453,0.057617,0.037598,-0.251953,-0.059814,-0.010071,-0.028442,0.12793,0.013,0.014343,-0.10791,0.12793,-0.178711,0.046143,0.006409,0.116699,-0.086914,-0.15332,-0.047119,-0.013977,0.118164,0.050293,0.199219,0.009155,-0.083008,0.029785,0.168945,-0.055176,-0.134766,-0.011658,-0.099121,-0.261719,-0.050049,0.048096,-0.071289,0.056641,-0.005463,-0.050537,-0.064941,0.035156,-0.09375,0.044434,-0.017578,-0.072754,0.033691,-0.024048,0.074219,0.249023,0.199219,-0.141602,0.140625,-0.008057,-0.097168,-0.173828,0.087891,0.243164,0.096191,0.029907,0.036621,-0.088867,0.175781,-0.263672,0.013855,-0.080566,0.04126,-0.078125,0.158203,0.074707,-0.158203,0.05127,0.095215,0.024658,-0.101562,0.04541,-0.00148,0.05542,-0.135742,0.069824,-0.11084,-0.371094,-0.071777,0.009644,0.142578,-0.267578,0.150391,0.053467,0.275391,-0.11084,0.112305,0.038818,-0.042725,-0.003845,-0.080078,0.053711,0.229492,-0.023071,-0.15332,-0.018188,-0.058594,0.10791,-0.014404,-0.059082,-0.039795,0.05542,0.216797,-0.134766,-0.060059,0.026855,0.026123,0.018677,-0.180664,-0.102539,-0.007568,-0.047119,-0.088867,0.15918,-0.05835,-0.028076,0.095703,-0.095215,0.004669,-0.086426,0.097168,0.091309,-0.145508,-0.126953,0.131836,-0.031982,-0.061768,0.014954,-0.211914,0.108398,-0.072754,0.15625,-0.164062,-0.412109,0.036377,0.077148,-0.037354,-0.126953,-0.046143,-0.182617,-0.048096,-0.093262,-0.075684,-0.237305,-0.238281,-0.25,-0.040771,0.070801,0.011047,-0.277344,-0.115723,-0.193359,0.016113,-0.222656,-0.135742,-0.001846,-0.271484,0.02063,0.119141,0.02002,-0.031494,0.081055,-0.062988,0.042969,0.016113,0.097168,0.209961,0.098633,-0.087402,-0.103027,-0.012146,-0.05542,-0.014404,-0.01709,-0.291016,-0.21582,-0.214844,0.214844,-0.152344,0.138672,0.140625,-0.037598,-0.07959,0.003769,-0.05542,-0.114746,0.041992,-0.060547,-0.07275
4,0.095215,0.28125,-0.210938,-0.030884,0.193359,-0.024902,-0.177734,0.068359,0.144531,-0.179688,-0.205078,-0.079102,-0.142578,0.030518,0.050537,0.261719,-0.091797,0.035645,0.076172,-0.142578,0.049072,0.049072,-0.065918,-0.086914,0.067871,0.061768,0.168945,0.030151,0.072266,-0.055664,0.012146,-0.024414,-0.15332,-0.054932,-0.208008,0.102539,-0.015625,-0.107422,-0.057373,-0.132812,0.018555,0.202148,-0.103516,-0.021729,0.069336,0.058105,-0.077637,-0.1875,-0.083496,-0.07959,-0.091797,-0.121582,0.077637,-0.091309,0.069824,-0.138672,5.2e-05,-0.146484,0.160156,-0.065918,0.071289,0.171875,-0.091309,0.060547,-0.138672,-0.144531,-0.000641,0.148438,-0.069824,0.197266,0.109863,0.077148,0.124023,-0.232422,-0.066406,0.107422,0.017456,-0.057617,-0.012451,-0.068359,-0.070312,0.176758,-0.039062,0.067383,-0.071777,-0.207031,-0.060303,0.034424,-0.0625,-0.19043,-0.069824,-0.3125,0.004639,0.140625,-0.308594,0.107422,0.014954,-0.088867,-0.232422,-0.165039,-0.125977,-0.064453,-0.116699,-0.135742,-0.005066,-0.021729,-0.104492,0.003433,0.164062,0.167969,0.3125,0.104004,-0.474609,-0.235352,0.255859,0.021362,-0.032959,0.05957,0.110352,-0.099609,0.004974,-0.02771,0.324219,-0.185547,0.028198,0.020996,-0.05542,-0.00412,-0.07959,-0.029419,0.130859,0.008179,0.068848,-0.050049,-0.064941,0.392578,-0.078125,-0.036377,0.207031,-0.013672,-0.018555,0.014343,0.142578,0.137695,0.075195,-0.080078,-0.11084,-0.053711,0.043701,-0.12793,-0.072754,-0.216797,-0.037354,-0.084961,-0.172852,0.032715,0.131836,-0.00264,-0.005096,0.139648,-0.196289,-0.206055,0.224609,0.055176,0.285156,0.275391,-0.124512,0.157227,0.070312,-0.121582,0.050781,-0.19043,0.106445,0.000725,-0.129883,0.228516,-0.111328,-0.322266,0.158203,0.353516,-0.021118,0.016235,-0.059814,0.026978,-0.07373,0.016846,-0.004181,-0.034424,-0.178711,-0.002899,-0.183594,0.310547,-0.161133,0.073242,0.112793,0.09082,0.071289,-0.015198,-0.192383,0.166992,-0.076172,-0.275391,0.083008,0.142578,0.289062,0.084473,-0.030396,-0.100098,-0.128906,0.038086,0.095215,0.00386,-
0.089355,0.176758,-0.035889,0.163086,-0.03418,-0.126953,0.169922,-0.447266,-0.289062,0.188477,0.07666,0.102051,0.314453,-0.138672,-0.119629,0.269531,-0.03418,0.070801,0.425781,0.002655,-0.067871,-0.192383,-0.251953,0.326172,0.261719,0.097168,0.119629,-0.029175,0.449219,0.382812,-0.053467,-0.125977,-0.28125,0.043457,-0.09668,0.216797,0.120117,0.213867,-0.257812,-0.125,-0.28125,-0.111328,-0.139648,0.194336,-0.141602,-0.051758,-0.191406,0.060303,-0.302734,-0.226562,-0.503906,0.296875,-0.143555,0.023071,-0.132812,-0.229492,-0.086426,0.046875,0.199219,-0.22168,-0.099609,-0.075684,0.085938,0.207031,-0.056885,-0.02832,-0.021118,0.103027,-0.117188,-0.053467,0.083008,-0.04126,-0.038818,0.255859,-0.053467,-0.15918,0.022217,0.02478,0.300781,0.046631,-0.287109,-0.037842,-0.095703,0.172852,-0.061523,-0.008789,-0.095215,-0.228516,0.068848,0.253906,-0.044922,0.090332,0.010864,0.130859,-0.064453,0.064941,-0.059082,-0.083008,-0.217773,-0.114746,-0.04834,0.216797,-0.054199,0.117676,0.024048,-0.044434,-0.011475,-0.121094,-0.071777,-0.028809,0.099609,-0.107422,-0.025635,0.119141,0.053711,0.035645,-0.124023,-0.137695,-0.132812,0.177734,-0.139648,-0.224609,-0.246094,0.216797,-0.125,0.013184,-0.157227,-0.108887,-0.052246,0.244141,-0.146484,-0.09668,0.300781,-0.044678,-0.204102,-0.001884,-0.094238,-0.064453,-0.014709,-0.038818,0.147461,0.022949,0.021606,-0.177734,0.062256,-0.1875,0.143555,0.227539,-0.289062,0.241211,-0.324219,-0.052979,0.016968,-0.047363,-0.031982,-0.065918,0.241211,-0.043701,-0.056152,-0.129883,0.131836,-0.039551,0.129883,0.166992,-0.057617,-0.216797,0.147461,-0.294922,-0.166016,0.090332,-0.068359,0.249023,0.053955,-0.25,0.138672,284


In [None]:
# Regression neural network built with fastai's tabular API.

# Preprocessing steps that fastai applies to the data. Categorify helps deal with
# categorical columns, but in the code's present state it should not do anything,
# because the inputs were already turned into embedding values.
procs_nn = [Categorify]
# Split the columns into continuous and categorical ones. Again, everything should be continuous.
cont_nn, cat_nn = cont_cat_split(df_w2v, dep_var=dep_var)
# Hold out 20% of the rows; random_state fixes the split so it is repeatable.
train_idx, test_idx = train_test_split(df_w2v.index, test_size=0.2, random_state = 0)
# NOTE(review): test_idx is handed to fastai as the second element of `splits`,
# i.e. it serves as the validation set during training, not as an untouched test set.
splits = (list(train_idx), list(test_idx))
to_nn = TabularPandas(df_w2v, procs_nn, cat_nn, cont_nn,
                      splits=splits, y_names=dep_var)
dls = to_nn.dataloaders(1024) # 1024 is the batch size
# `layers` gives the sizes of the linear layers used in the network; the loss is mean squared error.
# y_range bounds the predictions to (0, NUM_LANGUAGE) -- presumably because a pair
# cannot colexify in more languages than the database contains; confirm.
colex_model_1 = tabular_learner(dls, y_range=(0,NUM_LANGUAGE), layers=[500, 500, 500, 100, 100, 100, 100],
                        n_out=1, loss_func=F.mse_loss)

In [None]:
colex_model_1.lr_find() # Runs fastai's learning-rate finder to help pick the learning rate used for training

In [None]:
# Fit the model for 100 epochs with a maximum learning rate of 0.003.
# fit_one_cycle is a variation of fit that schedules the learning rate and momentum
# over the run (warm-up followed by annealing) instead of keeping them fixed.
colex_model_1.fit_one_cycle(100, 0.003)

In [None]:
# test_words allows us to predict p(a , b colexify) from our model for two words a, b

def words2tensor(word1, word2):
  """Build a single float32 tensor describing the word pair.

  The result is the embedding of ``word1`` followed by the embedding of
  ``word2`` (both looked up in the global ``wv``) and then five zeros.
  Any error raised by the ``wv`` lookup for an unknown word propagates.
  """
  first_vec = np.array(wv[word1])
  second_vec = np.array(wv[word2])
  # NOTE(review): presumably placeholders for five extra feature columns -- confirm.
  padding = np.array([0, 0, 0, 0, 0])
  stacked = np.concatenate((first_vec, second_vec, padding))
  return torch.tensor(stacked, dtype = torch.float32)
def words2df(word1, word2):
  """Return a one-row DataFrame holding the concatenated embeddings of two words.

  Both words are looked up in the global ``wv``; the embedding of ``word1``
  comes first. Columns are named by integer position (0, 1, 2, ...), one
  column per embedding component.
  """
  pair_vec = np.concatenate((np.array(wv[word1]), np.array(wv[word2])))
  # One single-element list per component, so the frame has exactly one row.
  columns = {position: [value] for position, value in enumerate(pair_vec)}
  return pd.DataFrame.from_dict(columns)
def test_words(word1, word2, model):
  dl = model.dls.test_dl(words2df(word1, word2))
  return model.get_preds(dl=dl)

# Query the learner trained above (colex_model_1) for an example word pair.
# NOTE(review): this call used the name `colex_net`, which none of the surrounding
# cells define -- confirm that no other learner was intended.
print(test_words("flee", "trio", colex_model_1))

(tensor([[1505.2473]]), None)
