- Notebook modified from https://www.kaggle.com/code/markwijkhuizen/planttraits2024-eda-training-pub.
- Training only, EDA part not included.
- Image model only, tabular data not used.

Modified from HDJOJO's original notebook with SWIN Transformer

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as imageio
import albumentations as A

from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch import nn
from tqdm.notebook import tqdm
from sklearn.preprocessing import StandardScaler

import torch
import timm
import glob
import torchmetrics
import time
import psutil
import os

tqdm.pandas()

In [2]:
class Config():
    IMAGE_SIZE = 384 # Sample: [224, 224]
    BACKBONE = 'swin_large_patch4_window12_384.ms_in22k_ft_in1k'
    TARGET_COLUMNS = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
    N_TARGETS = len(TARGET_COLUMNS)
    BATCH_SIZE = 10 # Sample: 96
    LR_MAX = 1e-4
    WEIGHT_DECAY = 0.01
    N_EPOCHS = 6 # Sample: 12
    TRAIN_MODEL = True
    IS_INTERACTIVE = os.environ['KAGGLE_KERNEL_RUN_TYPE'] == 'Interactive'
    
    # Added variables
    NUM_FOLDS = 5
    VALID_FOLD = 0  # Fold of validation data
        
CONFIG = Config()

In [3]:
# Read in training data
train_df = pd.read_csv('/kaggle/input/planttraits2024/train.csv')
train_df['file_path'] = train_df['id'].apply(lambda s: f'/kaggle/input/planttraits2024/train_images/{s}.jpeg')
train_df['jpeg_bytes'] = train_df['file_path'].progress_apply(lambda fp: open(fp, 'rb').read())
train_df.to_pickle('train.pkl') # serialize object into string form

  0%|          | 0/55489 [00:00<?, ?it/s]

### Data Filtering

In [4]:
# Sampled training set for faster training
print("Previous length:", len(train_df))
train_df = train_df.sample(frac=0.3, random_state=42)
print("Sampled length:", len(train_df))

Previous length: 55489
Sampled length: 16647


In [5]:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=CONFIG.NUM_FOLDS, shuffle=True, random_state=42)

# Create separate bin for each traits
for i, trait in enumerate(CONFIG.TARGET_COLUMNS):
    # Determine the bin edges dynamically based on the distribution of traits
    bin_edges = np.percentile(train_df[trait], np.linspace(0, 100, CONFIG.NUM_FOLDS + 1))
    train_df[f"bin_{i}"] = np.digitize(train_df[trait], bin_edges)

# Concatenate the bins into a final bin
train_df["final_bin"] = (
    train_df[[f"bin_{i}" for i in range(CONFIG.N_TARGETS)]]
    .astype(str)
    .agg("".join, axis=1)
)

# Perform the stratified split using final bin
train_df = train_df.reset_index(drop=True)
for fold, (train_idx, valid_idx) in enumerate(skf.split(train_df, train_df["final_bin"])):
    train_df.loc[valid_idx, "fold"] = fold
    
train_df.head()



Unnamed: 0,id,WORLDCLIM_BIO1_annual_mean_temperature,WORLDCLIM_BIO12_annual_precipitation,WORLDCLIM_BIO13.BIO14_delta_precipitation_of_wettest_and_dryest_month,WORLDCLIM_BIO15_precipitation_seasonality,WORLDCLIM_BIO4_temperature_seasonality,WORLDCLIM_BIO7_temperature_annual_range,SOIL_bdod_0.5cm_mean_0.01_deg,SOIL_bdod_100.200cm_mean_0.01_deg,SOIL_bdod_15.30cm_mean_0.01_deg,...,file_path,jpeg_bytes,bin_0,bin_1,bin_2,bin_3,bin_4,bin_5,final_bin,fold
0,174618466,-1.74747,129.25,36.32143,105.541046,1195.218506,44.992859,121,150,133,...,/kaggle/input/planttraits2024/train_images/174...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,2,3,2,4,3,1,232431,1.0
1,118794865,19.717112,2230.391113,212.631104,36.743,37.730415,10.992444,98,109,102,...,/kaggle/input/planttraits2024/train_images/118...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,1,3,5,3,5,5,135355,0.0
2,169048426,11.849193,550.698975,108.540817,70.325554,640.514771,31.989796,129,146,144,...,/kaggle/input/planttraits2024/train_images/169...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,2,4,1,5,3,2,241532,4.0
3,196586748,17.708334,380.75,82.0,100.55294,416.421143,25.200001,146,162,156,...,/kaggle/input/planttraits2024/train_images/196...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,2,5,2,2,3,4,252234,4.0
4,179552188,25.9825,3389.233398,432.033325,46.333939,72.031807,10.663334,106,111,110,...,/kaggle/input/planttraits2024/train_images/179...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,1,2,5,1,5,5,125155,2.0


In [6]:
train = train_df[train_df["fold"] != CONFIG.VALID_FOLD]
valid = train_df[train_df["fold"] == CONFIG.VALID_FOLD] # Fold 0 is validation
train[CONFIG.TARGET_COLUMNS + ["fold"]].describe()

Unnamed: 0,X4_mean,X11_mean,X18_mean,X50_mean,X26_mean,X3112_mean,fold
count,13317.0,13317.0,13317.0,13317.0,13317.0,13317.0,13317.0
mean,0.524384,245.9094,61353.2,14.014289,4089.621,1628823.0,2.499887
std,0.174708,18434.69,4082940.0,1384.488946,302886.6,186823300.0,1.11811
min,-1.623941,6.78e-05,3.2e-08,9.7e-05,0.0001790715,7.69e-08,1.0
25%,0.413418,10.53573,0.3104459,1.180538,0.5688102,252.5172,1.0
50%,0.509781,15.06442,0.7112426,1.486108,2.578003,711.2605,2.0
75%,0.622163,19.54607,3.637176,1.927618,14.6019,2146.108,3.0
max,4.475172,1504254.0,272049400.0,159759.8977,31065550.0,21559110000.0,4.0


In [7]:
class PlantDataPreProcess:
    lower_quantile = 0.005
    upper_quantile = 0.995
    log_transform = np.log10

In [8]:
# Filter data
print("Num samples before filtering:", len(train))

for trait in CONFIG.TARGET_COLUMNS:
    lower_bound = train[trait].quantile(PlantDataPreProcess.lower_quantile)
    upper_bound = train[trait].quantile(PlantDataPreProcess.upper_quantile)
    train = train[(train[trait] >= lower_bound) & (train[trait] <= upper_bound)]
    
print("Num samples After filtering:", len(train))
train[CONFIG.TARGET_COLUMNS].describe()

Num samples before filtering: 13317
Num samples After filtering: 12540


Unnamed: 0,X4_mean,X11_mean,X18_mean,X50_mean,X26_mean,X3112_mean
count,12540.0,12540.0,12540.0,12540.0,12540.0,12540.0
mean,0.523047,15.730888,3.26438,1.62325,42.969495,1849.32086
std,0.145291,7.60301,5.435135,0.641962,171.425975,3080.443482
min,0.188629,2.800125,0.032818,0.494606,0.006891,8.965067
25%,0.414165,10.715475,0.319718,1.191952,0.59592,263.162364
50%,0.51015,15.064569,0.711372,1.488062,2.596759,716.241347
75%,0.62133,19.403532,3.447378,1.914411,13.985242,2091.202458
max,0.962373,56.307174,31.872379,4.631433,2430.687341,28842.265893


In [9]:
# Log10 transformation for all traits except X4
LOG_FEATURES = ['X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
y_train = train[CONFIG.TARGET_COLUMNS]

for skewed_trait in LOG_FEATURES:
    y_train.loc[:, skewed_trait] = y_train[skewed_trait].apply(PlantDataPreProcess.log_transform)

y_train.describe()

Unnamed: 0,X4_mean,X11_mean,X18_mean,X50_mean,X26_mean,X3112_mean
count,12540.0,12540.0,12540.0,12540.0,12540.0,12540.0
mean,0.523047,1.14424,0.004508,0.179229,0.461365,2.848448
std,0.145291,0.223549,0.672241,0.16386,1.035198,0.650915
min,0.188629,0.447177,-1.483884,-0.305741,-2.161744,0.952554
25%,0.414165,1.030011,-0.495233,0.076259,-0.224812,2.420224
50%,0.51015,1.177957,-0.147903,0.172621,0.414432,2.855059
75%,0.62133,1.287881,0.537489,0.282035,1.14567,3.320396
max,0.962373,1.750564,1.503414,0.665715,3.385729,4.460029


In [10]:
# Normalize to mean = 0, std dev = 1
from sklearn.preprocessing import StandardScaler

SCALER = StandardScaler()
y_train = SCALER.fit_transform(y_train)

# y_train_df = pd.DataFrame(y_train, columns=CONFIG.TARGET_COLUMNS)
# y_train_df.describe()

### SWIN Transformer Data Load

In [11]:
# Previous filtering by HDJOJO
# Keep only data that is in range 0.005 to 0.985
# for column in CONFIG.TARGET_COLUMNS:
#     lower_quantile = train[column].quantile(0.005)
#     upper_quantile = train[column].quantile(0.985)  
#     train = train[(train[column] >= lower_quantile) & (train[column] <= upper_quantile)]

CONFIG.N_TRAIN_SAMPLES = len(train)
CONFIG.N_STEPS_PER_EPOCH = (CONFIG.N_TRAIN_SAMPLES // CONFIG.BATCH_SIZE)
CONFIG.N_STEPS = CONFIG.N_STEPS_PER_EPOCH * CONFIG.N_EPOCHS + 1

test = pd.read_csv('/kaggle/input/planttraits2024/test.csv')
test['file_path'] = test['id'].apply(lambda s: f'/kaggle/input/planttraits2024/test_images/{s}.jpeg')
test['jpeg_bytes'] = test['file_path'].progress_apply(lambda fp: open(fp, 'rb').read())
test.to_pickle('test.pkl')

print('N_TRAIN_SAMPLES:', len(train), 'N_TEST_SAMPLES:', len(test))

  0%|          | 0/6545 [00:00<?, ?it/s]

N_TRAIN_SAMPLES: 12540 N_TEST_SAMPLES: 6545


In [12]:
# Previous log scaling and normalization
# LOG_FEATURES = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']

# y_train = np.zeros_like(train[CONFIG.TARGET_COLUMNS], dtype=np.float32)
# for target_idx, target in enumerate(CONFIG.TARGET_COLUMNS):
#     v = train[target].values
#     if target in LOG_FEATURES:
#         v = np.log10(v) # take log10 base of all values
#     y_train[:, target_idx] = v # store log10 of target values

# SCALER = StandardScaler() # remove the mean and scale to unit variance.
# y_train = SCALER.fit_transform(y_train)

In [13]:
print("Train len:", len(train))
print("y_train len", len(y_train))

Train len: 12540
y_train len 12540


In [14]:
# Where did values come from?
# Likely Mean/std dev for each channel - Check! (only for train though)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

TRAIN_TRANSFORMS = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomSizedCrop(
            [448, 512],
            CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE, w2h_ratio=1.0, p=0.75),
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.25),
        A.ImageCompression(quality_lower=85, quality_upper=100, p=0.25),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

VALID_TRANSFORMS = A.Compose([
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

TEST_TRANSFORMS = A.Compose([
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

class Dataset(Dataset):
    def __init__(self, X_jpeg_bytes, y, transforms=None):
        self.X_jpeg_bytes = X_jpeg_bytes
        self.y = y
        self.transforms = transforms

    def __len__(self):
        return len(self.X_jpeg_bytes)

    def __getitem__(self, index):
        X_sample = self.transforms(
            image=imageio.imread(self.X_jpeg_bytes[index]),
        )['image']
        y_sample = self.y[index]
        
        return X_sample, y_sample

train_dataset = Dataset(
    train['jpeg_bytes'].values,
    y_train,
    TRAIN_TRANSFORMS,
)

train_dataloader = DataLoader(
        train_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=psutil.cpu_count(),
)


valid_dataset = Dataset(
    valid['jpeg_bytes'].values,
    valid['id'].values,
    VALID_TRANSFORMS,
)

test_dataset = Dataset(
    test['jpeg_bytes'].values,
    test['id'].values,
    TEST_TRANSFORMS,
)

In [15]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
                CONFIG.BACKBONE,
                num_classes=CONFIG.N_TARGETS,
                pretrained=True)  # Use pretrained SWIN Transformer model
        
    def forward(self, inputs):
        return self.backbone(inputs)

model = Model()
model = model.to('cuda')
print(model)

model.safetensors:   0%|          | 0.00/801M [00:00<?, ?B/s]

Model(
  (backbone): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
    (layers): Sequential(
      (0): SwinTransformerStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=192, out_features=576, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=192, out_features=192, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path1): Identity()
            (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bias=

In [16]:
def get_lr_scheduler(optimizer):
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=CONFIG.LR_MAX,
        total_steps=CONFIG.N_STEPS,
        pct_start=0.1,
        anneal_strategy='cos',
        div_factor=1e1,
        final_div_factor=1e1,
    )

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        self.sum += val.sum()
        self.count += val.numel()
        self.avg = self.sum / self.count

MAE = torchmetrics.regression.MeanAbsoluteError().to('cuda')
R2 = torchmetrics.regression.R2Score(num_outputs=CONFIG.N_TARGETS, multioutput='uniform_average').to('cuda')
LOSS = AverageMeter()

Y_MEAN = torch.tensor(y_train).mean(dim=0).to('cuda')
EPS = torch.tensor([1e-6]).to('cuda')

def r2_loss(y_pred, y_true):
    ss_res = torch.sum((y_true - y_pred)**2, dim=0)
    ss_total = torch.sum((y_true - Y_MEAN)**2, dim=0)
    ss_total = torch.maximum(ss_total, EPS)
    r2 = torch.mean(ss_res / ss_total)
    return r2

# How is this R2 Loss?
LOSS_FN = nn.SmoothL1Loss() # r2_loss

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=CONFIG.LR_MAX,
    weight_decay=CONFIG.WEIGHT_DECAY,
)

LR_SCHEDULER = get_lr_scheduler(optimizer)

In [17]:
print("Start Training:")
for epoch in range(CONFIG.N_EPOCHS):
    MAE.reset()
    R2.reset()
    LOSS.reset()
    model.train()
        
    for step, (X_batch, y_true) in enumerate(train_dataloader):
        X_batch = X_batch.to('cuda')
        y_true = y_true.to('cuda')
        t_start = time.perf_counter_ns()
        y_pred = model(X_batch)
        loss = LOSS_FN(y_pred, y_true)
        LOSS.update(loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        LR_SCHEDULER.step()
        MAE.update(y_pred, y_true)
        R2.update(y_pred, y_true)
            
        if not CONFIG.IS_INTERACTIVE and (step+1) == CONFIG.N_STEPS_PER_EPOCH:
            print(
                f'EPOCH {epoch+1:02d}, {step+1:04d}/{CONFIG.N_STEPS_PER_EPOCH} | ' + 
                f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
            )
        elif CONFIG.IS_INTERACTIVE:
            print(
                f'\rEPOCH {epoch+1:02d}, {step+1:04d}/{CONFIG.N_STEPS_PER_EPOCH} | ' + 
                f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                end='\n' if (step + 1) == CONFIG.N_STEPS_PER_EPOCH else '', flush=True,
            )

torch.save(model, 'model.pth')

Start Training:
EPOCH 01, 1254/1254 | loss: 0.2918, mae: 0.6253, r2: 0.3386, step: 1.367s, lr: 9.87e-05
EPOCH 02, 1254/1254 | loss: 0.2366, mae: 0.5490, r2: 0.4752, step: 1.247s, lr: 8.45e-05
EPOCH 03, 1254/1254 | loss: 0.1826, mae: 0.4713, r2: 0.6044, step: 1.247s, lr: 5.91e-05
EPOCH 04, 1254/1254 | loss: 0.1270, mae: 0.3840, r2: 0.7320, step: 1.251s, lr: 3.09e-05
EPOCH 05, 1254/1254 | loss: 0.0810, mae: 0.3016, r2: 0.8329, step: 1.245s, lr: 9.14e-06
EPOCH 06, 1254/1254 | loss: 0.0582, mae: 0.2542, r2: 0.8812, step: 1.253s, lr: 1.00e-06


In [18]:
# Validate on validation set
VALID_ROWS = []
model.eval()

for X_sample_valid, valid_id in tqdm(valid_dataset):
    with torch.no_grad():
        y_pred = model(X_sample_valid.unsqueeze(0).to('cuda')).detach().cpu().numpy()
    
    y_pred = SCALER.inverse_transform(y_pred).squeeze()
    row = {'id': valid_id}
    
    for k, v in zip(CONFIG.TARGET_COLUMNS, y_pred):
        if k in LOG_FEATURES:
            row[k] = 10 ** v
        else:
            row[k] = v

    VALID_ROWS.append(row)
    
valid_predict_df = pd.DataFrame(VALID_ROWS)
print(valid_predict_df.head())

  0%|          | 0/3330 [00:00<?, ?it/s]

          id   X4_mean   X11_mean   X18_mean  X50_mean    X26_mean  \
0  118794865  0.646986  13.482144  12.652718  1.465857   26.648741   
1  185845871  0.792908   4.926701  16.500450  2.439518    5.033408   
2  196227508  0.565385  18.445932  20.561461  1.508761   22.223874   
3  159217870  0.668195  10.570216  16.360542  1.695614  132.924201   
4  168645446  0.593980  15.981668   0.664253  1.756311   22.358055   

    X3112_mean  
0  4050.818206  
1  1955.455083  
2  6607.579114  
3  3133.482194  
4  1193.671285  


In [19]:
# valid_y_true
print(valid[['id'] + CONFIG.TARGET_COLUMNS].head())
valid_y_true = torch.tensor(valid[CONFIG.TARGET_COLUMNS].to_numpy()).to('cuda')

           id   X4_mean   X11_mean   X18_mean  X50_mean    X26_mean  \
1   118794865  0.302113  14.513728   9.278336  1.558418  111.584412   
5   185845871  0.594929   5.427188  26.591154  2.306473   16.550158   
12  196227508  0.636018   6.974774  26.726731  1.419748    0.151257   
17  159217870  0.529810   4.796313  10.377317  1.498278  466.730586   
21  168645446  0.658024  16.052974   0.831718  2.012095   14.677592   

     X3112_mean  
1   4291.222335  
5   2274.619747  
12   244.107546  
17  1614.541606  
21  1433.747650  


In [20]:
# Evaluate valid scores
valid_y_pred = torch.tensor(valid_predict_df[CONFIG.TARGET_COLUMNS].to_numpy()).to('cuda')

with torch.no_grad():
    # Calculate R2 Loss
    print("Validation R2 Loss (using r2_loss):", r2_loss(valid_y_pred, valid_y_true))

    # Loss function (smooth L1 loss)
    valid_loss = LOSS_FN(valid_y_pred, valid_y_true)
    print("Validation loss (Smooth L1 loss): ", valid_loss)

Validation R2 Loss (using r2_loss): tensor(0.8040, device='cuda:0', dtype=torch.float64)
Validation loss (Smooth L1 loss):  tensor(5937.1850, device='cuda:0', dtype=torch.float64)


In [21]:
# VALID_Y_MEAN = torch.tensor(y_train).mean(dim=0).to('cuda')

# def r2_loss_valid(y_pred, y_true):
#     ss_res = torch.sum((y_true - y_pred)**2, dim=0)
#     ss_total = torch.sum((y_true - VALID_Y_MEAN)**2, dim=0)
#     ss_total = torch.maximum(ss_total, torch.tensor([1e-6]))
#     r2 = torch.mean(ss_res / ss_total)
#     return r2

# print("R2 Score valid (using r2_loss_valid):", 1 - r2_loss_valid(valid_y_pred, valid_y_true))

In [22]:
# Scratch code to test R2 loss: random produced around R2 score = -92
# v_len = len(valid_y_true)
# train_y_true = torch.tensor(train[0:v_len][CONFIG.TARGET_COLUMNS].to_numpy())
# print("Train and valid R2 score:", 1 - r2_loss_valid(valid_y_true, train_y_true))

MAE_valid = torchmetrics.regression.MeanAbsoluteError().to('cuda')
R2_valid = torchmetrics.regression.R2Score(num_outputs=CONFIG.N_TARGETS, multioutput='uniform_average').to('cuda')

print("Torch R2 valid:", R2_valid(valid_y_pred, valid_y_true))
print("Torch MAE valid:", MAE_valid(valid_y_pred, valid_y_true))

Torch R2 valid: tensor(0.0530, device='cuda:0')
Torch MAE valid: tensor(5937.5181, device='cuda:0')


In [23]:
# Predict on test set
SUBMISSION_ROWS = []
model.eval()

for X_sample_test, test_id in tqdm(test_dataset):
    with torch.no_grad():
        y_pred = model(X_sample_test.unsqueeze(0).to('cuda')).detach().cpu().numpy()
    
    y_pred = SCALER.inverse_transform(y_pred).squeeze()
    row = {'id': test_id}
    
    for k, v in zip(CONFIG.TARGET_COLUMNS, y_pred):
        if k in LOG_FEATURES:
            row[k.replace('_mean', '')] = 10 ** v
        else:
            row[k.replace('_mean', '')] = v

    SUBMISSION_ROWS.append(row)
    
submission_df = pd.DataFrame(SUBMISSION_ROWS)
print(submission_df.head())
submission_df.to_csv('submission.csv', index=False)
print("Submit!")

  0%|          | 0/6545 [00:00<?, ?it/s]

          id        X4        X11       X18       X50        X26        X3112
0  201238668  0.534756   9.409504  0.957096  1.647736   1.690741   246.899895
1  202310319  0.584410  15.445874  0.437605  1.274473   0.350256   993.442084
2  202604412  0.642294  13.634584  0.865754  1.881790  12.575053   257.690118
3  201353439  0.530408  19.785141  0.202842  1.149939   0.631471  1296.487020
4  195351745  0.502907  10.823110  0.204258  1.483025   1.341975   125.982876
Submit!
