In [1]:
cd ..

/home/xavier/projects/godatathon_2020


In [2]:
# Automatically reload imported modules (e.g. src.model.*) before each cell executes,
# so edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.model.trainer import RNNModel
from src.model.dataset import NovartisDataset

### Params

In [4]:
# Architecture hyperparameters passed to RNNModel below.
input_dim, hidden_dim, num_layers = 3, 20, 1

In [5]:
# Global seed (also reused as the train/val split random_state) and the learning rate given to RNNModel.
SEED, LR = 27, 5e-4

In [6]:
# Seed the RNGs Lightning manages; the seed value is echoed as the cell output.
pl.seed_everything(seed=SEED)

27

# Data

In [7]:
FEATURES_PATH = "data/features/final_features.csv"
df = pd.read_csv(FEATURES_PATH)

In [8]:
# Order each series chronologically within its country/brand.
df = df.sort_values(by=["country", "brand", "month_num"])

### Preprocessing

#### Keep only country/brand series with 24 post-generic months (temporary — to be removed later)

In [9]:
# NOTE: in the future the loss will only be computed on the months actually available
# for each country/brand, i.e. if a series only has volume up to month 20, months 21-24
# will be padded/ignored in the loss instead of the series being dropped.
# Count the post-generic months (month_num >= 0) of every country/brand pair.
country_brand_post_count = (
    df.loc[df["month_num"] >= 0]
    .groupby(["country", "brand"])
    .size()
    .reset_index(name="post_months_count")
)

In [10]:
# Attach the post-generic month count to every row of its series.
df = pd.merge(df, country_brand_post_count, how="right", on=["country", "brand"])

In [11]:
# Keep only series that have exactly 24 post-generic months
df = df.loc[df["post_months_count"].eq(24)]

In [12]:
# The helper count column is no longer needed after filtering
df = df.drop(columns=["post_months_count"])

#### Add country-brand column

In [13]:
# Single "<country>-<brand>" key identifying each series.
df["country_brand"] = df["country"].str.cat(df["brand"], sep="-")

---

# Train/Val Split

### Train

In [14]:
from sklearn.model_selection import train_test_split

In [15]:
# Distinct series keys, in order of first appearance.
country_brands = df["country_brand"].unique()

In [16]:
# Train/Val split
# Split at the series level so all months of one country/brand land in the same split.
country_brands_train, country_brands_val = train_test_split(
    country_brands, test_size=0.2, random_state=SEED
)

In [17]:
is_train_series = df["country_brand"].isin(country_brands_train)
volume_train = df.loc[is_train_series].copy()

#### Dataset/DataLoader

In [18]:
ds_train = NovartisDataset(volume_train)
# One series per batch, shuffled each epoch.
dl_train = DataLoader(dataset=ds_train, batch_size=1, shuffle=True, num_workers=0)

In [19]:
# Sanity check: fetch the first training sample explicitly. The previous
# `for batch in ds_train: break` silently left `batch` undefined on an empty dataset;
# indexing raises immediately instead. ds_train is map-style (it is shuffled by its
# DataLoader), so iteration started at index 0 anyway.
batch = ds_train[0]

### Validation

In [20]:
is_val_series = df["country_brand"].isin(country_brands_val)
volume_val = df.loc[is_val_series].copy()

In [21]:
ds_val = NovartisDataset(volume_val)
# Validation does not need shuffling: val_loss is aggregated over all batches, and a
# fixed order keeps evaluation deterministic across epochs (was shuffle=True).
dl_val = DataLoader(ds_val, batch_size=1, num_workers=0, shuffle=False)

# Lightning

In [22]:
# Both callbacks track the validation loss logged by the model.
monitored_metric = "val_loss"
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=monitored_metric)
early_stopping_callback = pl.callbacks.EarlyStopping(monitor=monitored_metric)

In [23]:
# Use a GPU only when one is present; a hard-coded gpus=1 fails on CPU-only machines.
n_gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(max_epochs=50, gpus=n_gpus, callbacks=[checkpoint_callback, early_stopping_callback])
model = RNNModel(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers, lr=LR)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [24]:
# Positional loaders: (model, train loader, validation loader).
trainer.fit(model, dl_train, dl_val)


  | Name  | Type    | Params
----------------------------------
0 | model | Seq2Seq | 7.9 K 


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

# Predict

In [21]:
# Checkpoint used for inference. The dead assignment to an older run
# (lightning_logs/version_5/checkpoints/epoch=49.ckpt) was removed.
# NOTE(review): right after training in this session, checkpoint_callback.best_model_path
# points at the best checkpoint without hard-coding a version directory.
model_path = "lightning_logs/version_8/checkpoints/epoch=13.ckpt"

In [22]:
# Rebuild the trained model from the saved checkpoint.
model = RNNModel.load_from_checkpoint(checkpoint_path=model_path)

In [26]:
# Reload the full feature table: the training df above kept only 24-month series.
df = pd.read_csv("data/features/final_features.csv")
df["country_brand"] = df["country"].str.cat(df["brand"], sep="-")

In [24]:
# Template listing the country/brand/month rows that must be predicted.
submissions = pd.read_csv("data/raw/submission_template.csv")
submissions["country_brand"] = submissions["country"].str.cat(submissions["brand"], sep="-")

In [27]:
# Keep only the series requested by the submission template, ordered chronologically per series.
in_submission = df["country_brand"].isin(submissions["country_brand"])
df_test = df.loc[in_submission].sort_values(["country", "brand", "month_num"])

In [28]:
df_test.head(n=5)

Unnamed: 0,country,brand,month_num,country_id,brand_id,num_generics,package_id,channel_rate_A,channel_rate_B,channel_rate_C,therapeutic_id,avg_12_volume,max_volume,month_sin,month_cos,volume_norm,country_brand
76478,country_1,brand_121,-101,0,25,0.08,6,0.0,0.017237,0.0,7,35999789.0,38294953.2,0.5,0.8660254,0.002871,country_1-brand_121
76479,country_1,brand_121,-100,0,25,0.08,6,0.0,0.017237,0.0,7,35999789.0,38294953.2,0.866025,0.5,0.022482,country_1-brand_121
76480,country_1,brand_121,-99,0,25,0.08,6,0.0,0.017237,0.0,7,35999789.0,38294953.2,1.0,6.123234000000001e-17,0.037999,country_1-brand_121
76481,country_1,brand_121,-98,0,25,0.08,6,0.0,0.017237,0.0,7,35999789.0,38294953.2,0.866025,-0.5,0.049187,country_1-brand_121
76482,country_1,brand_121,-97,0,25,0.08,6,0.0,0.017237,0.0,7,35999789.0,38294953.2,0.5,-0.8660254,0.06401,country_1-brand_121


### Test

In [29]:
ds_test = NovartisDataset(df_test)
# Order must be preserved: predictions are mapped back to series by batch position.
dl_test = DataLoader(ds_test, batch_size=1, shuffle=False, num_workers=0)

In [30]:
# One max_volume per series; .item() raises if a series carries more than one distinct value.
max_volume_series = (
    df.groupby("country_brand")["max_volume"]
    .unique()
    .map(lambda distinct_values: distinct_values.item())
)

In [31]:
# Reuse the exact dataset that dl_test iterates over instead of building a second copy,
# so ds_test.group_keys is guaranteed to line up with the batch order in the prediction loop.
ds_test = dl_test.dataset

In [33]:
predictions = []

model.eval()
# Inference only: no autograd graph is needed for these forward passes.
with torch.no_grad():
    for n, batch in enumerate(tqdm(dl_test)):
        # dl_test has batch_size=1 and does not shuffle, so batch n is group n of ds_test.
        # These lookups are per series, so they are done once per batch, not once per month.
        country, brand = ds_test.group_keys[n]
        # Predictions are normalized; scale them back by the series' max_volume.
        volume_scaling = max_volume_series.loc[country + "-" + brand].item()

        x = batch["x"]
        # NOTE(review): y_norm is passed to the model at inference time as well — confirm
        # the model does not teacher-force on it in eval mode.
        y = batch["y_norm"]

        y_hat = model(x, y)
        y_hat_numpy = y_hat.squeeze(dim=1).detach().numpy()

        for month, vol_pred in enumerate(y_hat_numpy.flatten()):
            scaled_pred = vol_pred * volume_scaling
            # TODO: the 95% bounds currently equal the point prediction (zero-width interval).
            predictions.append({"country": country,
                                "brand": brand,
                                "month_num": month,
                                "pred_95_low": scaled_pred,
                                "prediction": scaled_pred,
                                "pred_95_high": scaled_pred})

100%|██████████| 191/191 [00:06<00:00, 31.44it/s]


In [34]:
# One row per (country, brand, month_num) prediction.
df_preds = pd.DataFrame(data=predictions)
df_preds.head(n=5)

Unnamed: 0,country,brand,month_num,pred_95_low,prediction,pred_95_high
0,country_1,brand_121,0,28507690.0,28507690.0,28507690.0
1,country_1,brand_121,1,27824560.0,27824560.0,27824560.0
2,country_1,brand_121,2,26509550.0,26509550.0,26509550.0
3,country_1,brand_121,3,25111910.0,25111910.0,25111910.0
4,country_1,brand_121,4,23779490.0,23779490.0,23779490.0


# Submission

In [None]:
# Left-join predictions onto the template so every requested row is kept.
merge_cols = ["country", "brand", "month_num"]
template_keys = submissions[merge_cols]
final_submissions = pd.merge(template_keys, df_preds, how="left", on=merge_cols)
final_submissions.head(n=5)

In [None]:
# Overwrite predictions with volumes that are already known.
# NOTE(review): `volume` was never defined anywhere in this notebook, so this cell raised
# NameError on a fresh kernel. It is now loaded explicitly. The path is an assumption —
# TODO confirm it is the raw table with country / brand / month_num / volume columns.
volume = pd.read_csv("data/raw/gx_volume.csv")

index_cols = ["country", "brand", "month_num"]
final_submissions = final_submissions.set_index(index_cols)
volume = volume.set_index(index_cols)

# Vectorized replacement for the row-by-row iterrows loop. As before, every column of a
# known row (pred_95_low, prediction, pred_95_high) is set to the known volume.
known_idx = final_submissions.index.intersection(volume.index)
for col in final_submissions.columns:
    final_submissions.loc[known_idx, col] = volume.loc[known_idx, "volume"]

In [None]:
# Turn country / brand / month_num back into regular columns.
final_submissions = final_submissions.reset_index(drop=False)

In [None]:
# NOTE(review): "sumbission" is misspelled; kept as-is to match existing submission file names.
SUBMISSION_PATH = "data/submissions/sumbission_04.csv"
final_submissions.to_csv(SUBMISSION_PATH, index=False)