In [1]:
!pip install -q --upgrade wandb

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import wandb
import numpy as np
import matplotlib
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
# Switch matplotlib to the 'Agg' backend for this session.
matplotlib.use('Agg')

In [3]:
# Kaggle input locations of the Goodreads train/test CSV files.
TRAIN_PATH = '/kaggle/input/goodreads-books-reviews-290312/goodreads_train.csv'
TEST_PATH = '/kaggle/input/goodreads-books-reviews-290312/goodreads_test.csv'

# Weights & Biases settings shared by the cells below.
params = dict(
    WANDB_PROJECT='review_classifier',
    ENTITY='lilouuch',
    # Identity mapping over the integers 0..5 (class index -> class value).
    CLASSES={label: label for label in range(6)},
    RAW_DATA_AT='Goodreads_Books_Review_Rating',
    PROCESSED_DATA_AT='Goodreads_Books_Review_Rating_load',
)

In [4]:
# Open a W&B run for this notebook, tagged with job_type "split".
run = wandb.init(
    project=params['WANDB_PROJECT'],
    entity=params['ENTITY'],
    job_type="split",
)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [5]:
# Declare the raw-data artifact (":latest" alias) as an input of this run.
raw_artifact_ref = params['RAW_DATA_AT'] + ":latest"
raw_data_at = run.use_artifact(raw_artifact_ref)

In [6]:
# Download the artifact's files locally; the returned directory path is
# displayed as the cell output.
raw_data_at.download()

[34m[1mwandb[0m: Downloading large artifact Goodreads_Books_Review_Rating:latest, 2146.55MB. 4 files... 
[34m[1mwandb[0m:   4 of 4 files downloaded.  
Done. 0:0:50.6


'/kaggle/working/artifacts/Goodreads_Books_Review_Rating:v0'

In [7]:
# Load the training CSV from the directory the artifact was actually downloaded
# to. The artifact is requested with the ":latest" alias, so hard-coding the
# ":v0" directory would read a stale (or missing) copy once a newer version
# exists. download() returns the local directory and reuses files that are
# already present.
train_df = pd.read_csv(f"{raw_data_at.download()}/train.csv")

In [8]:
# Number of folds to split the training data into.
n_splits = 10
# Stratify on 'rating' so every fold keeps the overall rating distribution;
# shuffling with a fixed seed keeps the assignment reproducible.
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# split() yields *positional* row indices, so collect the fold ids in a
# positional array instead of writing through label-based .loc (which is only
# equivalent while train_df has a default RangeIndex). -1 marks "unassigned".
fold_ids = np.full(len(train_df), -1, dtype=np.int64)
for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df['rating'])):
    # Each row appears in exactly one validation split; that split is its fold.
    fold_ids[val_idx] = fold

# Every row must have received a fold id before the column is attached.
assert (fold_ids >= 0).all(), "some rows were not assigned to a fold"
train_df['fold'] = fold_ids

In [9]:
# Number of rows assigned to each fold id.
train_df['fold'].value_counts()

fold
4    90000
9    90000
7    90000
3    90000
2    90000
5    90000
1    90000
8    90000
0    90000
6    90000
Name: count, dtype: int64

In [10]:
# Label each row with its stage: rows in fold 1 form the validation set,
# all remaining rows are training data.
in_valid_fold = train_df['fold'].eq(1)
train_df['Stage'] = np.where(in_valid_fold, 'valid', 'train')

In [11]:
# Number of rows per Stage value ('train' / 'valid').
train_df['Stage'].value_counts()

Stage
train    810000
valid     90000
Name: count, dtype: int64

In [12]:
# Keep only the rows whose Stage is 'valid'.
valid_df = train_df.loc[train_df['Stage'].eq('valid')]

In [13]:
# Summarize the validation subset: row count, columns, non-null counts, dtypes.
valid_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 90000 entries, 14 to 899995
Data columns (total 13 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   user_id       90000 non-null  object
 1   book_id       90000 non-null  int64 
 2   review_id     90000 non-null  object
 3   rating        90000 non-null  int64 
 4   review_text   90000 non-null  object
 5   date_added    90000 non-null  object
 6   date_updated  90000 non-null  object
 7   read_at       80672 non-null  object
 8   started_at    62593 non-null  object
 9   n_votes       90000 non-null  int64 
 10  n_comments    90000 non-null  int64 
 11  fold          90000 non-null  int64 
 12  Stage         90000 non-null  object
dtypes: int64(5), object(8)
memory usage: 9.6+ MB


In [14]:
# Describe the artifact that carries the split file. The CSV attached to it is
# written from the full train_df (every row, tagged with 'fold' and 'Stage'),
# not from valid_df alone, so the description and metadata say so. The artifact
# name, type and the existing "shapes" key are unchanged so that anything
# consuming this artifact keeps working.
artifact = wandb.Artifact(
        "Goodreads_Books_Review_Rating_VAL", 
        type="dataset_valid",
        description=("full training set with 'fold' and 'Stage' columns; "
                     "rows with Stage == 'valid' are the validation split"),
        metadata={"source": "kaggle",
                  # Shape of the validation subset only.
                  "shapes": [valid_df.shape],
                  # Shape of the frame that is written to the attached CSV.
                  "split_file_shape": train_df.shape}
    )

In [15]:
# Write the full frame (including the 'fold' and 'Stage' columns) to a CSV in
# the current working directory, without the index column.
train_df.to_csv('train_val_split.csv', index=False)

In [18]:
# Attach the split CSV to the artifact under the name "train_val_split.csv".
# Use the same working-directory-relative path the file was written to, instead
# of assuming the working directory is /kaggle/working.
artifact.add_file('train_val_split.csv', name="train_val_split.csv")

ArtifactManifestEntry(path='train_val_split.csv', digest='8AaEQr7djn1FBG//yVxqOA==', size=1141525157, local_path='/root/.local/share/wandb/artifacts/staging/tmpjt1kbke3')

In [19]:
# Log the artifact as an output of the active W&B run.
run.log_artifact(artifact)

<Artifact Goodreads_Books_Review_Rating_VAL>

In [20]:
# Finish the W&B run; the cell output shows the upload progress.
run.finish()

VBox(children=(Label(value='1088.644 MB of 1088.644 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))