In [26]:
import torch
import torch.nn as nn

class LR(nn.Module):
    def __init__(self):
        super().__init__()
        self.ll = nn.Linear(1, 1)
        self.loss_fn = nn.MSELoss()

    def forward(self, x, labels=None):
        preds = self.ll(x)

        if labels is not None:
            loss = self.loss_fn(preds, labels)
            return {"loss": loss, "logits": preds}

        return {"logits": preds}


In [27]:
class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {
            "x": self.x[idx],
            "labels": self.y[idx],
        }


In [37]:
inp = torch.tensor([[1.],[2.],[3.],[4.]])
out = torch.tensor([[1.],[2.],[3.],[4.]])

dataset = SimpleDataset(inp, out)

inp = torch.tensor([[5.],[6.],[7.],[8.]])
out = torch.tensor([[5.],[6.],[7.],[8.]])

eval_dataset = SimpleDataset(inp, out)


In [40]:
from transformers import TrainerCallback

class LRLossCallback(TrainerCallback):
    def __init__(self):
        self.history = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return

        if "loss" in logs:
            lr = logs.get("learning_rate")
            self.history.append({
                "step": state.global_step,
                "lr": lr,
                "loss": logs["loss"],
            })


In [41]:
from transformers import TrainingArguments

def make_args(lr):
    return TrainingArguments(
        output_dir="./tmp",
        per_device_train_batch_size=4,
        num_train_epochs=10,
        logging_steps=1,
        report_to="none",
        save_strategy="no",
        remove_unused_columns=False,  # ‚Üê IMPORTANT
        eval_strategy="epoch",
        fp16=True,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
    )



In [39]:
all_loss_list = []

for exp in range(5):
    lr_val = 10 ** (-exp)
    # print("lr =", lr_val)

    model = LR()

    trainer = Trainer(
        model=model,
        args=make_args(lr_val),
        train_dataset=dataset,
        eval_dataset=eval_dataset,
        callbacks=[callback],
    )

    trainer.train()

    # extract logged losses
    losses = [
        log["loss"]
        for log in trainer.state.log_history
        if "loss" in log
    ]

    all_loss_list.append(losses)


lr = 1


Epoch,Training Loss,Validation Loss
1,10.6056,74.322723
2,10.6052,74.320175
3,10.6047,74.317802
4,10.6043,74.315735
5,10.6039,74.314018
6,10.6036,74.312714
7,10.6034,74.311813
8,10.6032,74.311272
9,10.6031,74.311028
10,10.6031,74.310966


lr = 0.1


Epoch,Training Loss,Validation Loss
1,5.0763,25.914953
2,5.076,25.913446
3,5.0757,25.91205
4,5.0754,25.910824
5,5.0751,25.909813
6,5.0749,25.909042
7,5.0747,25.908508
8,5.0746,25.908192
9,5.0745,25.908043
10,5.0745,25.908007


lr = 0.01


Epoch,Training Loss,Validation Loss
1,5.0763,25.914953
2,5.076,25.913446
3,5.0757,25.91205
4,5.0754,25.910824
5,5.0751,25.909813
6,5.0749,25.909042
7,5.0747,25.908508
8,5.0746,25.908192
9,5.0745,25.908043
10,5.0745,25.908007


lr = 0.001


Epoch,Training Loss,Validation Loss
1,5.0763,25.914953
2,5.076,25.913446
3,5.0757,25.91205
4,5.0754,25.910824
5,5.0751,25.909813
6,5.0749,25.909042
7,5.0747,25.908508
8,5.0746,25.908192
9,5.0745,25.908043
10,5.0745,25.908007


lr = 0.0001


Epoch,Training Loss,Validation Loss
1,5.0763,25.914953
2,5.076,25.913446
3,5.0757,25.91205
4,5.0754,25.910824
5,5.0751,25.909813
6,5.0749,25.909042
7,5.0747,25.908508
8,5.0746,25.908192
9,5.0745,25.908043
10,5.0745,25.908007


In [23]:
all_loss_list

[[6.6005,
  1.8453,
  1.2564,
  0.4155,
  0.437,
  0.1807,
  0.2225,
  0.1447,
  0.0515,
  0.056],
 [16.8734,
  14.0048,
  11.654,
  9.7481,
  8.2221,
  7.0196,
  6.0917,
  5.3981,
  4.9062,
  4.5918],
 [16.8734,
  16.5744,
  16.3076,
  16.0723,
  15.8678,
  15.6936,
  15.5492,
  15.4341,
  15.3481,
  15.2909],
 [16.8734,
  16.8434,
  16.8164,
  16.7924,
  16.7714,
  16.7535,
  16.7385,
  16.7266,
  16.7176,
  16.7116],
 [16.8734,
  16.8704,
  16.8677,
  16.8653,
  16.8632,
  16.8614,
  16.8599,
  16.8587,
  16.8578,
  16.8572]]