In [1]:
import sys
sys.path.append("../")

from fgi import *
from torch import nn, empty, softmax, randn, cat
from torch import optim, no_grad
from torchmetrics import Accuracy
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

In [2]:
class Digit:
    def __init__(self, *args, **kwargs):
        pass

In [None]:
class ClassificationDigit(NonCodeProblem):
    def __init__(self, _id, *args, **kwargs):
        super().__init__(_id, *args, **kwargs)
        self._r = ImageRepresent(img_shape=(1, 28, 28), patch_size=7, num_heads=1, phi_dim=128)
        self._enhance = EnhanceRepresentUnit(phi_dim=128, dropout=0.2)
        self._property = PropertyUnit(phi_dim=128, dropout=0.2)
        self._specialized = nn.Linear(128, 10)

        self._update_additional_infor()
        
    def forward(self, x, allow_cluster : bool = False,*args, **kwargs):
        q1 = self._r(x)
        q1 = self._enhance(q1)

        q2 = self._property(q1)
        q2 = self._specialized(q1 + q2)
        
        q2 = softmax(q2, dim = 1)
        if allow_cluster:
            return q2, self._enhance._memory(q1)
        return q2

    @property
    def _as_object(self):
        return Digit
    
    def recognize_unknown(self, x, *args, **kwargs):
        """
        Lấy kết quả embedding nháp
        """
        x = self._r(x)
        x = self._enhance._memory(x)
        return x

In [None]:
class DigitLearner(LightningLearner):
    def __init__(self, problem, *args, **kwargs):
        super().__init__(problem, *args, **kwargs)
        self._loss_fn = nn.CrossEntropyLoss()
        self._taccuracy = Accuracy(task='multiclass', num_classes=10)
        self._vaccuracy = Accuracy(task='multiclass', num_classes=10)
        self.example_input_array = randn((1, 1, 28, 28))
    
    def _aggerate_loss(self, y_hat, y, *args, **kwargs):
        return self._loss_fn(y_hat, y)
    
    def configure_optimizers(self):
        # LR scheduler
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(self._optimizer, mode="min", patience=3)
        return { "optimizer" : self._optimizer, "lr_scheduler" : scheduler, "monitor" : "val_loss" }
    
    def training_step(self, batch, batch_idx, *args, **kwargs):
        x, y = batch
        y_hat = self(x)
        loss = self._aggerate_loss(y_hat, y)
        self._taccuracy.update(y_hat, y)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        self.log("train_acc", self._taccuracy, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        return loss
    
    def on_test_batch_end(self, batch, batch_idx, dataloader_idx = 0):
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log('end_lr', lr, logger=True, prog_bar=True, on_epoch=True)

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        x, y = batch
        y_hat = self(x)
        loss = self._aggerate_loss(y_hat, y)
        self._vaccuracy.update(y_hat, y)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        self.log("val_acc", self._vaccuracy, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        return loss
    
    def on_train_end(self):
        max_examples = 1000  # Giới hạn để tránh log quá nặng
        collected = 0

        all_images = []
        all_labels = []
        all_embeddings = []

        self.eval()
        with no_grad():
            for batch in self.trainer.val_dataloaders:
                x, y = batch
                x = x.to(self.device)
                y = y.to(self.device)

                # Encode ảnh thành vector đặc trưng (ví dụ: self._problem.encode)
                embedding = self._problem.recognize_unknown(x)

                all_images.append(x.cpu())
                all_labels.append(y.cpu())
                all_embeddings.append(embedding.cpu())

                collected += x.size(0)
                if collected >= max_examples:
                    break  # Dừng sớm nếu vượt quá max_examples

        # Nối lại và cắt đúng số lượng
        all_images = cat(all_images, dim=0)[:max_examples]
        all_labels = cat(all_labels, dim=0)[:max_examples]
        all_embeddings = cat(all_embeddings, dim=0)[:max_examples]

        # Ghi embedding lên TensorBoard
        self.logger.experiment.add_embedding(
            mat=all_embeddings,
            metadata=[str(label.item()) for label in all_labels],
            label_img=all_images,
            global_step=self.global_step,
            tag="final_embedding"
        )
    
    def test_step(self, batch, batch_idx, *args, **kwargs):
        x, y = batch
        y_hat = self(x)
        loss = self._aggerate_loss(y_hat, y)
        self._vaccuracy.update(y_hat, y)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        self.log("test_acc", self._vaccuracy, prog_bar=True, on_step=False, on_epoch=True, logger=True)


In [5]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=32, shuffle=True, num_workers=7, persistent_workers=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=32, shuffle=False, num_workers=7, persistent_workers=True)

In [6]:
solver = ClassificationDigit("digit_classification")
learner = DigitLearner(solver)
for _id in learner.learnable.keys():
    learner.learnable = (_id, True)
logger = TensorBoardLogger("digit_problem", name="exp", log_graph=True)
early = EarlyStopping("val_loss", mode='min', patience=8, verbose=True)
checkpoint = ModelCheckpoint(filename="best", monitor="val_loss", verbose=True, mode="min")
learner.compile(optim.SGD, "")

In [7]:
x = randn((32, 1, 28, 28))
y = solver(x)
print(y.shape)
print(y)

torch.Size([32, 10])
tensor([[0.0797, 0.0840, 0.1123, 0.0880, 0.1262, 0.1077, 0.1108, 0.1400, 0.0960,
         0.0553],
        [0.0810, 0.0841, 0.1153, 0.0887, 0.1306, 0.1086, 0.1067, 0.1347, 0.0935,
         0.0569],
        [0.0804, 0.0878, 0.1111, 0.0882, 0.1338, 0.1077, 0.1087, 0.1318, 0.0952,
         0.0552],
        [0.0794, 0.0885, 0.1121, 0.0884, 0.1304, 0.1059, 0.1073, 0.1341, 0.0959,
         0.0580],
        [0.0807, 0.0852, 0.1167, 0.0891, 0.1323, 0.1045, 0.1091, 0.1332, 0.0950,
         0.0542],
        [0.0806, 0.0837, 0.1170, 0.0899, 0.1284, 0.1080, 0.1065, 0.1380, 0.0930,
         0.0550],
        [0.0788, 0.0847, 0.1142, 0.0871, 0.1320, 0.1083, 0.1106, 0.1336, 0.0941,
         0.0567],
        [0.0796, 0.0869, 0.1149, 0.0872, 0.1281, 0.1067, 0.1089, 0.1376, 0.0947,
         0.0554],
        [0.0782, 0.0834, 0.1137, 0.0904, 0.1293, 0.1083, 0.1070, 0.1385, 0.0948,
         0.0565],
        [0.0807, 0.0840, 0.1176, 0.0869, 0.1349, 0.1061, 0.1087, 0.1352, 0.0919,
       

In [8]:
trainer = Trainer(accelerator="auto", max_epochs=20, logger=logger, callbacks=[early, checkpoint])
trainer.fit(learner, train_loader, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params | Mode  | In sizes | Out sizes
----------------------------------------------------------------------------------
0 | _problem   | ClassificationDigit | 109 K  | train | ?        | ?        
1 | _loss_fn   | CrossEntropyLoss    | 0      | train | ?        | ?        
2 | _taccuracy | MulticlassAccuracy  | 0      | train | ?        | ?        
3 | _vaccuracy | MulticlassAccuracy  | 0      | train | ?        | ?        
----------------------------------------------------------------------------------
109 K     Trainable params
0         Non-trainable params
109 K     Total params
0.437     Total estimated model params size (MB)
21        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 1875/1875 [01:08<00:00, 27.44it/s, v_num=2, train_loss_step=1.970, val_loss=2.040, val_acc=0.430, train_loss_epoch=2.180, train_acc_epoch=0.270]

Metric val_loss improved. New best score: 2.039
Epoch 0, global step 1875: 'val_loss' reached 2.03914 (best 2.03914), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 1: 100%|██████████| 1875/1875 [00:45<00:00, 41.08it/s, v_num=2, train_loss_step=1.930, val_loss=1.920, val_acc=0.549, train_loss_epoch=1.960, train_acc_epoch=0.507]

Metric val_loss improved by 0.121 >= min_delta = 0.0. New best score: 1.919
Epoch 1, global step 3750: 'val_loss' reached 1.91855 (best 1.91855), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 2: 100%|██████████| 1875/1875 [00:54<00:00, 34.18it/s, v_num=2, train_loss_step=1.780, val_loss=1.880, val_acc=0.586, train_loss_epoch=1.900, train_acc_epoch=0.563]

Metric val_loss improved by 0.041 >= min_delta = 0.0. New best score: 1.878
Epoch 2, global step 5625: 'val_loss' reached 1.87782 (best 1.87782), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 3: 100%|██████████| 1875/1875 [00:54<00:00, 34.47it/s, v_num=2, train_loss_step=1.720, val_loss=1.880, val_acc=0.584, train_loss_epoch=1.880, train_acc_epoch=0.581]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.877
Epoch 3, global step 7500: 'val_loss' reached 1.87669 (best 1.87669), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 4: 100%|██████████| 1875/1875 [00:53<00:00, 34.78it/s, v_num=2, train_loss_step=1.930, val_loss=1.860, val_acc=0.605, train_loss_epoch=1.870, train_acc_epoch=0.589]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 1.859
Epoch 4, global step 9375: 'val_loss' reached 1.85887 (best 1.85887), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 5: 100%|██████████| 1875/1875 [00:53<00:00, 34.91it/s, v_num=2, train_loss_step=1.860, val_loss=1.880, val_acc=0.586, train_loss_epoch=1.860, train_acc_epoch=0.604]

Epoch 5, global step 11250: 'val_loss' was not in top 1


Epoch 6: 100%|██████████| 1875/1875 [00:55<00:00, 33.79it/s, v_num=2, train_loss_step=1.840, val_loss=1.860, val_acc=0.606, train_loss_epoch=1.850, train_acc_epoch=0.613]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 1.856
Epoch 6, global step 13125: 'val_loss' reached 1.85583 (best 1.85583), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 7: 100%|██████████| 1875/1875 [00:58<00:00, 32.07it/s, v_num=2, train_loss_step=1.730, val_loss=1.860, val_acc=0.599, train_loss_epoch=1.840, train_acc_epoch=0.621]

Epoch 7, global step 15000: 'val_loss' was not in top 1


Epoch 8: 100%|██████████| 1875/1875 [00:55<00:00, 33.65it/s, v_num=2, train_loss_step=1.880, val_loss=1.820, val_acc=0.640, train_loss_epoch=1.830, train_acc_epoch=0.631]

Metric val_loss improved by 0.034 >= min_delta = 0.0. New best score: 1.822
Epoch 8, global step 16875: 'val_loss' reached 1.82169 (best 1.82169), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 9: 100%|██████████| 1875/1875 [00:55<00:00, 33.92it/s, v_num=2, train_loss_step=1.710, val_loss=1.810, val_acc=0.645, train_loss_epoch=1.820, train_acc_epoch=0.636]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 1.815
Epoch 9, global step 18750: 'val_loss' reached 1.81494 (best 1.81494), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 10: 100%|██████████| 1875/1875 [00:51<00:00, 36.26it/s, v_num=2, train_loss_step=1.740, val_loss=1.810, val_acc=0.648, train_loss_epoch=1.820, train_acc_epoch=0.645]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.813
Epoch 10, global step 20625: 'val_loss' reached 1.81292 (best 1.81292), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 11: 100%|██████████| 1875/1875 [00:52<00:00, 35.65it/s, v_num=2, train_loss_step=1.690, val_loss=1.820, val_acc=0.642, train_loss_epoch=1.810, train_acc_epoch=0.649]

Epoch 11, global step 22500: 'val_loss' was not in top 1


Epoch 12: 100%|██████████| 1875/1875 [00:52<00:00, 35.62it/s, v_num=2, train_loss_step=1.780, val_loss=1.830, val_acc=0.627, train_loss_epoch=1.810, train_acc_epoch=0.655]

Epoch 12, global step 24375: 'val_loss' was not in top 1


Epoch 13: 100%|██████████| 1875/1875 [00:57<00:00, 32.42it/s, v_num=2, train_loss_step=1.920, val_loss=1.790, val_acc=0.673, train_loss_epoch=1.800, train_acc_epoch=0.661]

Metric val_loss improved by 0.026 >= min_delta = 0.0. New best score: 1.787
Epoch 13, global step 26250: 'val_loss' reached 1.78740 (best 1.78740), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 14: 100%|██████████| 1875/1875 [00:51<00:00, 36.41it/s, v_num=2, train_loss_step=1.930, val_loss=1.800, val_acc=0.662, train_loss_epoch=1.800, train_acc_epoch=0.664]

Epoch 14, global step 28125: 'val_loss' was not in top 1


Epoch 15: 100%|██████████| 1875/1875 [00:51<00:00, 36.55it/s, v_num=2, train_loss_step=1.740, val_loss=1.780, val_acc=0.677, train_loss_epoch=1.790, train_acc_epoch=0.671]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 1.784
Epoch 15, global step 30000: 'val_loss' reached 1.78387 (best 1.78387), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 16: 100%|██████████| 1875/1875 [00:51<00:00, 36.23it/s, v_num=2, train_loss_step=1.900, val_loss=1.790, val_acc=0.674, train_loss_epoch=1.790, train_acc_epoch=0.673]

Epoch 16, global step 31875: 'val_loss' was not in top 1


Epoch 17: 100%|██████████| 1875/1875 [00:49<00:00, 38.16it/s, v_num=2, train_loss_step=1.830, val_loss=1.790, val_acc=0.668, train_loss_epoch=1.780, train_acc_epoch=0.675]

Epoch 17, global step 33750: 'val_loss' was not in top 1


Epoch 18: 100%|██████████| 1875/1875 [00:53<00:00, 35.11it/s, v_num=2, train_loss_step=1.790, val_loss=1.770, val_acc=0.689, train_loss_epoch=1.780, train_acc_epoch=0.678]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 1.772
Epoch 18, global step 35625: 'val_loss' reached 1.77246 (best 1.77246), saving model to 'digit_problem\\exp\\version_2\\checkpoints\\best.ckpt' as top 1


Epoch 19: 100%|██████████| 1875/1875 [00:46<00:00, 40.08it/s, v_num=2, train_loss_step=1.720, val_loss=1.780, val_acc=0.680, train_loss_epoch=1.780, train_acc_epoch=0.680]

Epoch 19, global step 37500: 'val_loss' was not in top 1
`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1875/1875 [00:46<00:00, 40.06it/s, v_num=2, train_loss_step=1.720, val_loss=1.780, val_acc=0.680, train_loss_epoch=1.780, train_acc_epoch=0.680]


In [9]:
def release_gpu():
    pass

In [10]:
# Thử nghiệm, 
x = randn((1, 1, 28, 28))
y = solver.recognize_unknown(x)
print(y)

tensor([[ 0.0954, -0.0140,  0.0084,  0.0224,  0.1769,  0.0653,  0.0523,  0.0893,
         -0.0787, -0.0988,  0.1298,  0.0851,  0.0388,  0.0333,  0.0825,  0.0567,
          0.0262, -0.0435,  0.0611,  0.0188, -0.0274,  0.1573,  0.1744,  0.0270,
          0.0371, -0.0164, -0.0905, -0.0749,  0.0207, -0.0882,  0.0025,  0.0393,
         -0.1258, -0.1198, -0.0268,  0.0048, -0.0498,  0.0643,  0.1373, -0.0552,
         -0.0726,  0.1049,  0.0723, -0.0069,  0.1080,  0.0289, -0.0128, -0.0730,
          0.0257, -0.1121,  0.1582, -0.1215, -0.0264, -0.1740,  0.0367,  0.0222,
         -0.0128, -0.0140,  0.0477,  0.0984, -0.0434,  0.0044, -0.0108,  0.0329,
          0.0676, -0.0560,  0.0680, -0.0491, -0.0177,  0.1825, -0.0736,  0.0491,
          0.1740, -0.0237,  0.0150, -0.0125, -0.0224,  0.0714,  0.0304, -0.0910,
          0.0946, -0.0120, -0.0351,  0.0288, -0.1054,  0.1121,  0.2064,  0.0781,
         -0.0396,  0.1400, -0.0542,  0.0227,  0.0781, -0.0866,  0.0435,  0.0323,
         -0.0536,  0.1005,  

In [11]:
solver._property._projection.weight

Parameter containing:
tensor([[-0.8497],
        [-0.1649],
        [-0.6671],
        [-0.7005],
        [-1.1308],
        [-0.8200],
        [ 1.0125],
        [-0.0162],
        [-0.7855],
        [-0.3339],
        [-0.6511],
        [ 0.7315],
        [ 0.8304],
        [ 0.9673],
        [-0.4510],
        [-0.8589],
        [ 0.5375],
        [-0.5861],
        [-0.2845],
        [-1.0333],
        [-0.4609],
        [ 0.9132],
        [ 0.2548],
        [-1.0764],
        [ 0.8823],
        [-0.2858],
        [-0.7323],
        [ 1.0920],
        [ 0.6885],
        [-0.5932],
        [ 1.0859],
        [-0.5217],
        [-0.1651],
        [-0.2630],
        [ 0.3265],
        [-0.6915],
        [-0.4359],
        [ 0.0079],
        [-0.5314],
        [-0.5768],
        [-0.9093],
        [-0.2034],
        [ 0.3921],
        [-0.6015],
        [-0.1419],
        [-0.4693],
        [-0.5811],
        [-0.0611],
        [-0.2193],
        [-0.1723],
        [ 0.7413],
        [

In [12]:
solver(x)

tensor([[2.1737e-05, 1.2643e-07, 1.2497e-11, 9.3249e-01, 1.0302e-02, 1.7257e-11,
         1.2375e-09, 1.3433e-38, 5.7184e-02, 1.3740e-31]],
       grad_fn=<SoftmaxBackward0>)

In [13]:
z = solver(x)

In [14]:
z.sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [15]:
z.max()

tensor(0.9325, grad_fn=<MaxBackward1>)

In [16]:
trainer.test(learner, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 313/313 [00:05<00:00, 59.45it/s]


[{'test_loss': 1.7803601026535034,
  'test_acc': 0.680400013923645,
  'end_lr': 0.10000026226043701}]