In [26]:
import torch
import torch.nn as nn

class LR(nn.Module):
    """Single-feature linear regressor bundled with its own MSE objective.

    Holds one ``nn.Linear(1, 1)`` layer (``self.ll``) and an ``nn.MSELoss``
    criterion (``self.loss_fn``).
    """

    def __init__(self):
        super().__init__()
        self.ll = nn.Linear(1, 1)
        self.loss_fn = nn.MSELoss()

    def forward(self, x, labels=None):
        """Apply the linear layer and, if targets are supplied, score the result.

        Args:
            x: Input tensor fed to ``self.ll``; its last dimension must be 1
                because the layer is ``nn.Linear(1, 1)``.
            labels: Optional targets compared against the predictions with
                ``self.loss_fn``.

        Returns:
            ``{"logits": predictions}`` when ``labels`` is ``None``; otherwise
            ``{"loss": mse, "logits": predictions}`` (loss listed first).
        """
        logits = self.ll(x)

        # Inference path: nothing to score against.
        if labels is None:
            return {"logits": logits}

        return {"loss": self.loss_fn(logits, labels), "logits": logits}


In [27]:
class SimpleDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each entry of ``x`` with the entry of ``y`` at the same index."""

    def __init__(self, x, y):
        # Both containers are stored as given: no copy, no length validation.
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of samples, measured on ``x`` alone."""
        features = self.x
        return len(features)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with keys ``"x"`` and ``"labels"``."""
        feature = self.x[idx]
        target = self.y[idx]
        return {"x": feature, "labels": target}


In [37]:
# Training split: inputs 1..4 as a (4, 1) float column; targets are an
# independent copy of the same values (identity mapping y = x).
inp = torch.arange(1.0, 5.0).unsqueeze(1)
out = inp.clone()
dataset = SimpleDataset(inp, out)

# Evaluation split: same identity mapping on 5..8, i.e. outside the training
# range. `inp` / `out` are rebound here, so afterwards they hold the eval tensors.
inp = torch.arange(5.0, 9.0).unsqueeze(1)
out = inp.clone()
eval_dataset = SimpleDataset(inp, out)


In [40]:
from transformers import TrainerCallback

class LRLossCallback(TrainerCallback):
    """Trainer callback that records step, learning rate and loss for each log event that carries a loss."""

    def __init__(self):
        # One dict per recorded log event, in the order the events arrived.
        self.history = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append a record to ``self.history`` when ``logs`` contains ``"loss"``.

        Args:
            args, control, **kwargs: Accepted for the callback signature; unused here.
            state: Object whose ``global_step`` attribute is stored as ``"step"``.
            logs: Mapping of logged values, or ``None``. Events without a
                ``"loss"`` key are ignored.

        The ``"lr"`` field is ``logs.get("learning_rate")`` and is therefore
        ``None`` when the event carries no learning rate.
        """
        if logs is None or "loss" not in logs:
            return

        record = {
            "step": state.global_step,
            "lr": logs.get("learning_rate"),
            "loss": logs["loss"],
        }
        self.history.append(record)


In [65]:
from transformers import TrainingArguments

def make_args(lr, seed, output_dir=None):
    """Build the ``TrainingArguments`` for one run of the learning-rate sweep.

    Args:
        lr: Value passed as ``learning_rate`` (used with a cosine schedule).
        seed: Value passed as ``seed``; also names the default output directory.
        output_dir: Optional explicit output directory. When omitted it
            defaults to ``./tmp/run_{seed}``.

    Returns:
        A ``TrainingArguments`` instance configured for 10 epochs, batch size 4,
        logging every step, evaluation every epoch, and no checkpoint saving
        or external reporting.
    """
    # The path used to be built from the notebook-global loop variable `exp`,
    # so calling this function outside the sweep loop raised NameError (or
    # silently reused a stale value). The sweep passes seed=exp, so deriving
    # the name from `seed` yields the same directories for that caller.
    if output_dir is None:
        output_dir = f"./tmp/run_{seed}"

    return TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        num_train_epochs=10,
        logging_steps=1,
        report_to="none",
        save_strategy="no",
        # IMPORTANT: stop Trainer from dropping dataset fields it considers
        # unused, so the custom "x" field reaches the model's forward().
        remove_unused_columns=False,
        eval_strategy="epoch",
        fp16=False,
        learning_rate=lr,
        lr_scheduler_type="cosine",
        seed=seed,
    )



In [66]:
# `Trainer` was used below without ever being imported; import it here so the
# cell runs on a fresh kernel.
from transformers import Trainer

# Per-run training-loss curves, in sweep order: entry i belongs to lr = 10**-i.
all_loss_list = []

# Sweep the learning rate over 1, 0.1, 0.01, 0.001, 0.0001.
for exp in range(5):
    callback = LRLossCallback()
    lr_val = 10 ** (-exp)
    print("lr =", lr_val)

    # Fresh model per run so no weights carry over between learning rates.
    # NOTE(review): the model is built before the Trainer receives `seed`, so
    # this weight initialisation is presumably not controlled by it — confirm,
    # and consider Trainer's `model_init` if runs must be comparable.
    model = LR()

    trainer = Trainer(
        model=model,
        args=make_args(lr_val, seed=exp),
        train_dataset=dataset,
        eval_dataset=eval_dataset,
        callbacks=[callback],
    )

    trainer.train()

    # Keep only the log entries that carry a training loss.
    losses = [
        log["loss"]
        for log in trainer.state.log_history
        if "loss" in log
    ]

    all_loss_list.append(losses)

    # `callback.history` additionally holds step / lr / loss records for this
    # run; it is rebuilt on every iteration and not stored anywhere else.

Epoch,Training Loss,Validation Loss
1,5.0763,6.217173
2,2.0517,4.406288
3,1.5017,0.444595
4,0.0324,0.988045
5,0.1238,0.0041
6,0.0224,0.004147
7,0.0128,0.069165
8,0.0069,0.080134
9,0.0106,0.04292
10,0.0052,0.03078


Epoch,Training Loss,Validation Loss
1,9.4233,35.953953
2,7.3046,27.630583
3,5.498,20.89098
4,4.0524,15.750909
5,2.9657,12.058399
6,2.1978,9.56951
7,1.6891,8.017708
8,1.377,7.157342
9,1.2062,6.775292
10,1.131,6.679074


Epoch,Training Loss,Validation Loss
1,0.2214,0.221167
2,0.1894,0.157624
3,0.1607,0.108485
4,0.1365,0.073108
5,0.117,0.049394
6,0.1022,0.034596
7,0.0917,0.026051
8,0.0848,0.021607
9,0.0809,0.019714
10,0.0791,0.019246


Epoch,Training Loss,Validation Loss
1,19.8959,91.648773
2,19.8631,91.507187
3,19.8312,91.376022
4,19.8016,91.260971
5,19.7757,91.166161
6,19.7543,91.093781
7,19.738,91.043777
8,19.7267,91.013962
9,19.72,91.000153
10,19.7169,90.996613


Epoch,Training Loss,Validation Loss
1,5.8503,31.909046
2,5.8486,31.900692
3,5.8468,31.892946
4,5.8452,31.886147
5,5.8438,31.880545
6,5.8427,31.876263
7,5.8418,31.873306
8,5.8412,31.871542
9,5.8408,31.870724
10,5.8406,31.870514


In [61]:
import torch
# Sanity check: report whether PyTorch can see a CUDA device in this kernel.
print(torch.cuda.is_available())


False
