In [1]:
from pathlib import Path
import wandb
import pytorch_lightning as pl
from model import Seq2SeqModel
from dakshina_data import DakshinaDataModule

# W&B model artifact to evaluate. The commented-out reference is another checkpoint
# from the same project; swap the two lines to evaluate that run instead
# (presumably also switching the CSV filename in the export cell — confirm).
# checkpoint_reference = "livinNector-academic/dakshina/model-08wxmzak:best"
checkpoint_reference = "livinNector-academic/dakshina/model-njt8zyrr:best"

# Open a W&B run (metrics and figures logged in later cells attach to it), download
# the checkpoint artifact and restore the Lightning model from its model.ckpt file.
run = wandb.init(project="dakshina")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
model = Seq2SeqModel.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")

# Dakshina data for lang_code 'ta' (presumably Tamil); only the test split is used below.
datamodule = DakshinaDataModule(
    data_dir='dakshina_dataset_v1.0',
    lang_code='ta',
    batch_size=128,
    num_workers=2,
)
datamodule.setup()

# Inference over the test split; later cells zip `result` with the test batches,
# i.e. it is treated as one prediction entry per test batch.
result = pl.Trainer().predict(model, dataloaders=datamodule.test_dataloader())

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlivinnector[0m ([33mlivinNector-academic[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/necto/micromamba/envs/ml/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_p

Predicting: |          | 0/? [00:00<?, ?it/s]

In [2]:
# Decode every test batch back to strings and collect flat (source, target, pred) rows.
# NOTE(review): this re-creates test_dataloader() and pairs its batches with `result`,
# so it assumes the test loader yields batches in the same order as during prediction
# (i.e. no shuffling) — TODO confirm in DakshinaDataModule.
# strict=True (Python >= 3.10) makes a batch/row count mismatch raise instead of being
# silently truncated by zip, which would otherwise skew the accuracy computed later.
test_data_pred = []
for (source, target), pred in zip(datamodule.test_dataloader(), result, strict=True):
    sources = datamodule.source_tokenizer.batch_decode(source.tolist())
    targets = datamodule.target_tokenizer.batch_decode(target.tolist())
    preds = datamodule.target_tokenizer.batch_decode(pred)
    test_data_pred.extend(zip(sources, targets, preds, strict=True))


In [3]:
import pandas as pd
# One row per test example, persisted so predictions can be inspected outside the notebook.
# NOTE(review): "predictions_vannila.csv" was the alternative output name used by this
# cell; presumably it should track which checkpoint_reference is being evaluated.
pred_df = pd.DataFrame(
    data=test_data_pred,
    columns=["source", "target", "pred"],
)
pred_df.to_csv(path_or_buf="predictions_attention.csv", index=False)

In [4]:
# Fraction of test rows whose decoded prediction equals the decoded target string exactly;
# shown inline and recorded in the W&B run.
exact_match_accuracy = pred_df["target"].eq(pred_df["pred"]).mean()
print(f"Exact Match Accuracy:{exact_match_accuracy:.04f}")
wandb.log({"test_exact_match_accuracy": exact_match_accuracy})

Exact Match Accuracy:0.4908


In [5]:
import plotly.graph_objects as go
import difflib


def diff_ratio(s1, s2):
    """Return difflib's similarity ratio of two sequences, a float in [0.0, 1.0].

    The ratio is 2*M/T, where M is the number of matched elements and T is the
    combined length of both inputs. Identical inputs (including two empty ones)
    score 1.0.
    """
    matcher = difflib.SequenceMatcher(isjunk=None, a=s1, b=s2)
    return matcher.ratio()

def compute_colors(sim):
    """Map a similarity score to a (background, text) color pair for a table cell.

    The score is clamped to [0, 1] and interpolated linearly from pure red (0) to
    pure green (1). The text color is chosen from the background's weighted
    brightness (299/587/114 per mille for R/G/B): white on dark, black on light.

    Returns:
        (bg_color, text_color), e.g. ("rgb(0,255,0)", "black").
    """
    # Keep this exact clamp order: with sim=NaN it yields 0 rather than propagating NaN.
    level = max(0, min(sim, 1.0))
    r, g, b = int(255 * (1 - level)), int(255 * level), 0

    luminance = (r * 299 + g * 587 + b * 114) / 1000
    fg = "white" if luminance <= 125 else "black"
    return f"rgb({r},{g},{b})", fg

def pred_similarity_plot(df, title="String Similarity (target vs pred)", width=700, height=600):
    """Build a Plotly table comparing each target string with its prediction.

    Columns shown are source, target, pred and the difflib similarity ratio
    between target and pred (rounded to 2 decimals). The pred cell's background
    and font colors come from ``compute_colors`` applied to that ratio.

    Args:
        df: DataFrame with "source", "target" and "pred" columns. It is not modified.
        title: Figure title.
        width: Figure width passed to ``fig.update_layout``.
        height: Figure height passed to ``fig.update_layout``.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single ``go.Table``.
    """
    # Scores live in a standalone float Series aligned with df instead of being written
    # into the caller's frame. Iterating the two columns directly (rather than
    # df.apply(axis=1)) also works for an empty df, where apply can return a DataFrame.
    sim_ratio = pd.Series(
        [diff_ratio(target, pred) for target, pred in zip(df["target"], df["pred"])],
        index=df.index,
        dtype="float64",
    )

    bg_colors = []
    text_colors = []
    for sim in sim_ratio:
        bg, text = compute_colors(sim)
        bg_colors.append(bg)
        text_colors.append(text)

    # Only the pred column is color-coded; the other columns use plain styling.
    n_rows = len(df)
    plain_fill = ["white"] * n_rows
    plain_font = ["black"] * n_rows

    fig = go.Figure(
        data=[
            go.Table(
                header=dict(
                    values=["source", "target", "pred", "similarity"],
                    fill_color="lightgrey",
                    align="left",
                ),
                cells=dict(
                    values=[
                        df["source"],
                        df["target"],
                        df["pred"],
                        sim_ratio.round(2),
                    ],
                    fill_color=[plain_fill, plain_fill, bg_colors, plain_fill],
                    font=dict(color=[plain_font, plain_font, text_colors, plain_font]),
                    align="left",
                ),
            )
        ]
    )

    fig.update_layout(title=title, width=width, height=height)
    return fig

# Preview a random subset of predictions inline. A fixed random_state keeps the preview
# identical across re-runs, and capping n avoids DataFrame.sample raising when there
# are fewer than 20 predictions.
fig = pred_similarity_plot(pred_df.sample(n=min(20, len(pred_df)), random_state=0))
fig.show()

In [6]:
# Log 10 random samples of up to 20 predictions each to W&B under the "sample_prediction"
# key. Seeding every draw with its loop index makes the logged tables reproducible, and
# capping n avoids DataFrame.sample raising on a test set with fewer than 20 rows.
for i in range(10):
    fig = pred_similarity_plot(pred_df.sample(n=min(20, len(pred_df)), random_state=i))
    wandb.log({"sample_prediction": fig})

In [7]:
# Mark the W&B run opened with wandb.init(project="dakshina") in the first cell as finished.
wandb.finish()

0,1
test_exact_match_accuracy,▁

0,1
test_exact_match_accuracy,0.49082
