In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from training_utils import LinearSubspace, symmetric_cross_entropy
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
root = "data/baby"

image_feat = np.load(f"{root}/image_feat.npy")
v_prime_feat = np.load(f"{root}/v_prime_feat.npy")
text_feat = np.load(f"{root}/text_feat.npy")
t_prime_feat = np.load(f"{root}/t_prime_feat.npy")
typographic_feat = np.load(f"{root}/typographic_feat.npy")

class_id = np.arange(len(image_feat))
real_or_fake = np.ones(len(image_feat))  # Dummy values for compatibility

# Convert to float32 if necessary
image_feat = image_feat.astype(np.float32)
v_prime_feat = v_prime_feat.astype(np.float32)
text_feat = text_feat.astype(np.float32)
t_prime_feat = t_prime_feat.astype(np.float32)
typographic_feat = typographic_feat.astype(np.float32)

In [3]:
class FeatureDataset(Dataset):
    def __init__(self, img_i, img_t, img_ti, txt_i, txt_t, class_id, real_or_fake):
        self.img_i = torch.tensor(img_i)
        self.img_t = torch.tensor(img_t)
        self.img_ti = torch.tensor(img_ti)
        self.txt_i = torch.tensor(txt_i)
        self.txt_t = torch.tensor(txt_t)
        self.class_id = torch.tensor(class_id)
        self.real_or_fake = torch.tensor(real_or_fake)

    def __len__(self):
        return len(self.img_i)

    def __getitem__(self, idx):
        return (self.img_i[idx], self.img_t[idx], self.img_ti[idx],
                self.txt_i[self.class_id[idx]], self.txt_t[idx],
                self.class_id[idx], self.real_or_fake[idx])

In [37]:
from training_utils import AverageMeter
import torch.nn.functional as F

# 하이퍼파라미터
D_shared = 64  # Bottleneck dim
out_dim = 64     # 최종 투영 차원
batch_size = 128
num_epochs = 100
initial_lr = 1e-4

# 모델 정의
model = LinearSubspace(out_dim=out_dim, input_size=D_shared).to(device)
v_proj = torch.nn.Linear(4096, D_shared).to(device)
t_proj = torch.nn.Linear(384, D_shared).to(device)

# 옵티마이저 및 스케줄러
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(v_proj.parameters()) + list(t_proj.parameters()),
    lr=initial_lr
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4000, gamma=0.5)

# 데이터로더 구성
dataset = FeatureDataset(image_feat, v_prime_feat, v_prime_feat, text_feat, t_prime_feat, class_id, real_or_fake)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [38]:
class LearnWModel(torch.nn.Module):
    def __init__(self, input_size=64, out_dim=64):
        super().__init__()
        self.W = torch.nn.Linear(input_size, out_dim, bias=False)

    def forward(self, *features):
        return [self.W(x) for x in features]

In [39]:
model = LearnWModel().to(device)
# v_proj, t_proj는 위와 동일
optimizer = torch.optim.Adam(list(model.parameters()) + list(v_proj.parameters()) + list(t_proj.parameters()), lr=1e-4)

In [None]:
for epoch in tqdm(range(num_epochs), desc="Training Learn W"):
    for batch in dataloader:
        img_i, img_t, img_ti, txt_i, txt_t, *_ = [x.to(device) for x in batch]

        img_i = F.normalize(v_proj(img_i), dim=-1)
        img_t = F.normalize(v_proj(img_t), dim=-1)
        img_ti = F.normalize(v_proj(img_ti), dim=-1)
        txt_i = F.normalize(t_proj(txt_i), dim=-1)
        txt_t = F.normalize(t_proj(txt_t), dim=-1)

        vi, vt, vti, ti, tt = model(img_i, img_t, img_ti, txt_i, txt_t)

        L1 = symmetric_cross_entropy(vi, ti, device, 0.07)
        L2 = symmetric_cross_entropy(vt, tt, device, 0.07)
        L3 = symmetric_cross_entropy(vti, ti, device, 0.07)
        L4 = symmetric_cross_entropy(vti, tt, device, 0.07)
        L5 = symmetric_cross_entropy(vti, vt, device, 0.07)
        L6 = symmetric_cross_entropy(vti, vi, device, 0.07)

        loss = (-L1 - L2 - L6 + L3 + L4 + L5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")
    print(f"Epoch {epoch+1}, L1={L1.item():.4f}, L2={L2.item():.4f}, L3={L3.item():.4f}, "
      f"L4={L4.item():.4f}, L5={L5.item():.4f}, L6={L6.item():.4f}, Loss={loss.item():.4f}")

Training Learn W:   0%|          | 0/100 [00:00<?, ?it/s]

Training Learn W:   1%|          | 1/100 [00:01<02:05,  1.27s/it]

Epoch 1/100, Loss: -0.0788
Epoch 1, L1=2.3100, L2=2.2985, L3=2.3031, L4=2.2985, L5=2.2328, L6=2.3046, Loss=-0.0788


Training Learn W:   2%|▏         | 2/100 [00:01<01:31,  1.07it/s]

Epoch 2/100, Loss: -0.1389
Epoch 2, L1=2.3413, L2=2.2867, L3=2.3152, L4=2.2867, L5=2.1885, L6=2.3013, Loss=-0.1389


Training Learn W:   3%|▎         | 3/100 [00:02<01:12,  1.34it/s]

Epoch 3/100, Loss: -0.2217
Epoch 3, L1=2.3673, L2=2.3027, L3=2.3005, L4=2.3027, L5=2.1791, L6=2.3340, Loss=-0.2217


Training Learn W:   4%|▍         | 4/100 [00:02<00:58,  1.63it/s]

Epoch 4/100, Loss: -0.5457
Epoch 4, L1=2.3889, L2=2.4260, L3=2.2483, L4=2.4260, L5=2.0138, L6=2.4189, Loss=-0.5457


Training Learn W:   5%|▌         | 5/100 [00:03<00:49,  1.90it/s]

Epoch 5/100, Loss: -1.0180
Epoch 5, L1=2.5964, L2=2.4830, L3=2.2553, L4=2.4830, L5=1.9180, L6=2.5949, Loss=-1.0180


Training Learn W:   6%|▌         | 6/100 [00:03<00:45,  2.06it/s]

Epoch 6/100, Loss: -0.9214
Epoch 6, L1=2.8545, L2=2.4612, L3=2.5809, L4=2.4612, L5=1.8354, L6=2.4832, Loss=-0.9214


Training Learn W:   7%|▋         | 7/100 [00:04<00:41,  2.24it/s]

Epoch 7/100, Loss: -1.2314
Epoch 7, L1=2.6384, L2=2.9039, L3=2.6857, L4=2.9039, L5=1.7721, L6=3.0508, Loss=-1.2314


Training Learn W:   8%|▊         | 8/100 [00:04<00:39,  2.35it/s]

Epoch 8/100, Loss: -2.7065
Epoch 8, L1=4.0659, L2=2.4681, L3=3.0605, L4=2.4681, L5=1.6792, L6=3.3802, Loss=-2.7065


Training Learn W:   9%|▉         | 9/100 [00:04<00:37,  2.40it/s]

Epoch 9/100, Loss: -2.4562
Epoch 9, L1=4.2511, L2=2.8259, L3=3.4424, L4=2.8259, L5=1.6496, L6=3.2972, Loss=-2.4562


Training Learn W:  10%|█         | 10/100 [00:05<00:42,  2.14it/s]

Epoch 10/100, Loss: -5.2810
Epoch 10, L1=5.6395, L2=3.6647, L3=3.0003, L4=3.6647, L5=1.6365, L6=4.2783, Loss=-5.2810


Training Learn W:  11%|█         | 11/100 [00:05<00:40,  2.20it/s]

Epoch 11/100, Loss: -3.2073
Epoch 11, L1=4.7144, L2=5.6488, L3=4.7722, L4=5.6488, L5=1.5936, L6=4.8586, Loss=-3.2073


Training Learn W:  12%|█▏        | 12/100 [00:06<00:38,  2.31it/s]

Epoch 12/100, Loss: -3.8582
Epoch 12, L1=5.7151, L2=5.1265, L3=5.3881, L4=5.1265, L5=1.6362, L6=5.1674, Loss=-3.8582


Training Learn W:  13%|█▎        | 13/100 [00:06<00:35,  2.44it/s]

Epoch 13/100, Loss: -11.1938
Epoch 13, L1=8.8074, L2=4.8353, L3=3.8286, L4=4.8353, L5=1.6320, L6=7.8470, Loss=-11.1938


Training Learn W:  14%|█▍        | 14/100 [00:06<00:32,  2.62it/s]

Epoch 14/100, Loss: -11.6071
Epoch 14, L1=8.8719, L2=4.2013, L3=4.1686, L4=4.2013, L5=1.6176, L6=8.5214, Loss=-11.6071


Training Learn W:  15%|█▌        | 15/100 [00:07<00:35,  2.40it/s]

Epoch 15/100, Loss: -10.4691
Epoch 15, L1=11.1340, L2=6.0534, L3=6.6983, L4=6.0534, L5=1.9777, L6=8.0111, Loss=-10.4691


Training Learn W:  16%|█▌        | 16/100 [00:07<00:35,  2.33it/s]

Epoch 16/100, Loss: -3.8150
Epoch 16, L1=10.0446, L2=8.5413, L3=9.7872, L4=8.5413, L5=1.7880, L6=5.3455, Loss=-3.8150


Training Learn W:  17%|█▋        | 17/100 [00:08<00:35,  2.35it/s]

Epoch 17/100, Loss: -17.6627
Epoch 17, L1=12.8893, L2=8.7703, L3=4.5994, L4=8.7703, L5=1.7113, L6=11.0841, Loss=-17.6627


Training Learn W:  18%|█▊        | 18/100 [00:08<00:35,  2.34it/s]

Epoch 18/100, Loss: -13.6654
Epoch 18, L1=15.9010, L2=7.9685, L3=9.4852, L4=7.9685, L5=1.9045, L6=9.1542, Loss=-13.6654


Training Learn W:  19%|█▉        | 19/100 [00:09<00:49,  1.63it/s]

Epoch 19/100, Loss: -22.7072
Epoch 19, L1=17.6802, L2=10.3909, L3=7.2360, L4=10.3909, L5=1.6939, L6=13.9569, Loss=-22.7072


Training Learn W:  20%|██        | 20/100 [00:10<00:57,  1.38it/s]

Epoch 20/100, Loss: -22.5191
Epoch 20, L1=15.1068, L2=8.8292, L3=5.1449, L4=8.8292, L5=1.9887, L6=14.5459, Loss=-22.5191


Training Learn W:  21%|██        | 21/100 [00:11<01:09,  1.14it/s]

Epoch 21/100, Loss: -20.0474
Epoch 21, L1=15.6394, L2=5.5853, L3=3.4715, L4=5.5853, L5=2.3027, L6=10.1821, Loss=-20.0474


Training Learn W:  22%|██▏       | 22/100 [00:13<01:14,  1.05it/s]

Epoch 22/100, Loss: -4.6502
Epoch 22, L1=7.6113, L2=2.3065, L3=2.9800, L4=2.3065, L5=2.3069, L6=2.3257, Loss=-4.6502


Training Learn W:  23%|██▎       | 23/100 [00:13<01:03,  1.21it/s]

Epoch 23/100, Loss: 0.0176
Epoch 23, L1=2.3031, L2=2.3210, L3=2.3159, L4=2.3210, L5=2.3240, L6=2.3192, Loss=0.0176


Training Learn W:  24%|██▍       | 24/100 [00:14<00:54,  1.39it/s]

Epoch 24/100, Loss: -30.9130
Epoch 24, L1=14.2077, L2=4.9185, L3=4.1178, L4=4.9185, L5=1.9787, L6=22.8018, Loss=-30.9130


Training Learn W:  25%|██▌       | 25/100 [00:14<00:46,  1.62it/s]

Epoch 25/100, Loss: -16.4521
Epoch 25, L1=7.5128, L2=12.2706, L3=4.6897, L4=12.2706, L5=1.8045, L6=15.4335, Loss=-16.4521


Training Learn W:  26%|██▌       | 26/100 [00:15<00:46,  1.59it/s]

Epoch 26/100, Loss: -2.8189
Epoch 26, L1=5.5932, L2=2.3138, L3=2.7779, L4=2.3138, L5=2.3026, L6=2.3062, Loss=-2.8189


Training Learn W:  27%|██▋       | 27/100 [00:15<00:41,  1.75it/s]

Epoch 27/100, Loss: -0.0031
Epoch 27, L1=2.3050, L2=2.3081, L3=2.3035, L4=2.3081, L5=2.3032, L6=2.3047, Loss=-0.0031


Training Learn W:  28%|██▊       | 28/100 [00:15<00:37,  1.91it/s]

Epoch 28/100, Loss: -37.5811
Epoch 28, L1=20.8620, L2=2.5622, L3=2.4681, L4=2.5622, L5=2.5000, L6=21.6873, Loss=-37.5811


Training Learn W:  29%|██▉       | 29/100 [00:16<00:34,  2.07it/s]

Epoch 29/100, Loss: -6.2778
Epoch 29, L1=7.5377, L2=5.4121, L3=2.8387, L4=5.4121, L5=2.6833, L6=4.2620, Loss=-6.2778


Training Learn W:  30%|███       | 30/100 [00:16<00:32,  2.17it/s]

Epoch 30/100, Loss: -3.4406
Epoch 30, L1=2.3054, L2=4.2910, L3=2.8879, L4=4.2910, L5=2.8232, L6=6.8464, Loss=-3.4406


Training Learn W:  31%|███       | 31/100 [00:17<00:32,  2.14it/s]

Epoch 31/100, Loss: 0.0001
Epoch 31, L1=2.3366, L2=4.5878, L3=2.3315, L4=4.5878, L5=2.3186, L6=2.3134, Loss=0.0001


Training Learn W:  32%|███▏      | 32/100 [00:17<00:30,  2.20it/s]

Epoch 32/100, Loss: -3.1338
Epoch 32, L1=6.3496, L2=5.0266, L3=3.2192, L4=5.0266, L5=2.3029, L6=2.3062, Loss=-3.1338


Training Learn W:  33%|███▎      | 33/100 [00:18<00:36,  1.83it/s]

Epoch 33/100, Loss: -49.8186
Epoch 33, L1=34.0611, L2=5.3523, L3=3.8191, L4=5.3523, L5=2.3089, L6=21.8855, Loss=-49.8186


Training Learn W:  34%|███▍      | 34/100 [00:18<00:34,  1.89it/s]

Epoch 34/100, Loss: -53.5738
Epoch 34, L1=29.0577, L2=2.3066, L3=2.3051, L4=2.3066, L5=2.3027, L6=29.1238, Loss=-53.5738


Training Learn W:  35%|███▌      | 35/100 [00:19<00:32,  2.01it/s]

Epoch 35/100, Loss: -25.5134
Epoch 35, L1=24.8137, L2=3.9221, L3=7.0611, L4=3.9221, L5=3.4277, L6=11.1885, Loss=-25.5134


Training Learn W:  36%|███▌      | 36/100 [00:19<00:32,  1.99it/s]

Epoch 36/100, Loss: -26.5617
Epoch 36, L1=2.3754, L2=5.7673, L3=5.7428, L4=5.7673, L5=1.9777, L6=31.9068, Loss=-26.5617


Training Learn W:  37%|███▋      | 37/100 [00:20<00:30,  2.04it/s]

Epoch 37/100, Loss: -21.0330
Epoch 37, L1=2.3029, L2=19.4072, L3=9.2624, L4=19.4072, L5=1.8137, L6=29.8062, Loss=-21.0330


Training Learn W:  38%|███▊      | 38/100 [00:20<00:30,  2.06it/s]

Epoch 38/100, Loss: -28.7613
Epoch 38, L1=2.3649, L2=5.9311, L3=5.9917, L4=5.9311, L5=1.9776, L6=34.3657, Loss=-28.7613


Training Learn W:  39%|███▉      | 39/100 [00:21<00:31,  1.93it/s]

Epoch 39/100, Loss: -0.1095
Epoch 39, L1=2.4689, L2=6.1251, L3=2.3664, L4=6.1251, L5=2.3038, L6=2.3108, Loss=-0.1095


Training Learn W:  40%|████      | 40/100 [00:21<00:31,  1.92it/s]

Epoch 40/100, Loss: -28.4145
Epoch 40, L1=2.3048, L2=6.6884, L3=6.8961, L4=6.6884, L5=2.9018, L6=35.9075, Loss=-28.4145


Training Learn W:  41%|████      | 41/100 [00:22<00:28,  2.06it/s]

Epoch 41/100, Loss: -55.2130
Epoch 41, L1=38.7032, L2=6.9902, L3=19.2849, L4=6.9902, L5=2.7108, L6=38.5056, Loss=-55.2130


Training Learn W:  42%|████▏     | 42/100 [00:22<00:26,  2.15it/s]

Epoch 42/100, Loss: -100.7135
Epoch 42, L1=69.7388, L2=2.4616, L3=6.8271, L4=2.4616, L5=2.5836, L6=40.3855, Loss=-100.7135


Training Learn W:  43%|████▎     | 43/100 [00:23<00:24,  2.32it/s]

Epoch 43/100, Loss: -0.0396
Epoch 43, L1=2.3845, L2=2.4091, L3=2.3500, L4=2.4091, L5=2.3032, L6=2.3084, Loss=-0.0396


Training Learn W:  44%|████▍     | 44/100 [00:23<00:23,  2.34it/s]

Epoch 44/100, Loss: -0.5164
Epoch 44, L1=3.0841, L2=6.9379, L3=2.5558, L4=6.9379, L5=2.3306, L6=2.3188, Loss=-0.5164


Training Learn W:  45%|████▌     | 45/100 [00:23<00:23,  2.38it/s]

Epoch 45/100, Loss: 0.0196
Epoch 45, L1=2.3070, L2=2.4982, L3=2.3206, L4=2.4982, L5=2.3307, L6=2.3247, Loss=0.0196


Training Learn W:  46%|████▌     | 46/100 [00:24<00:21,  2.50it/s]

Epoch 46/100, Loss: -0.0278
Epoch 46, L1=2.3390, L2=2.3160, L3=2.3238, L4=2.3160, L5=2.3077, L6=2.3203, Loss=-0.0278


Training Learn W:  47%|████▋     | 47/100 [00:24<00:20,  2.53it/s]

Epoch 47/100, Loss: -74.1438
Epoch 47, L1=40.2067, L2=7.4838, L3=2.5365, L4=7.4838, L5=2.3028, L6=38.7764, Loss=-74.1438


Training Learn W:  48%|████▊     | 48/100 [00:25<00:25,  2.00it/s]

Epoch 48/100, Loss: -0.0030
Epoch 48, L1=2.3089, L2=2.3062, L3=2.3072, L4=2.3062, L5=2.3029, L6=2.3042, Loss=-0.0030


Training Learn W:  49%|████▉     | 49/100 [00:26<00:31,  1.64it/s]

Epoch 49/100, Loss: -155.8787
Epoch 49, L1=52.6822, L2=7.8047, L3=7.2726, L4=7.8047, L5=1.9776, L6=112.4467, Loss=-155.8787


Training Learn W:  50%|█████     | 50/100 [00:26<00:26,  1.87it/s]

Epoch 50/100, Loss: -31.2777
Epoch 50, L1=2.4241, L2=6.2160, L3=6.3108, L4=6.2160, L5=1.9777, L6=37.1421, Loss=-31.2777


Training Learn W:  51%|█████     | 51/100 [00:26<00:24,  2.03it/s]

Epoch 51/100, Loss: -45.1896
Epoch 51, L1=53.9052, L2=5.2265, L3=9.8941, L4=5.2265, L5=3.5681, L6=4.7465, Loss=-45.1896


Training Learn W:  52%|█████▏    | 52/100 [00:27<00:21,  2.20it/s]

Epoch 52/100, Loss: -38.0460
Epoch 52, L1=2.3082, L2=7.1317, L3=7.1371, L4=7.1317, L5=2.0271, L6=44.9019, Loss=-38.0460


Training Learn W:  53%|█████▎    | 53/100 [00:27<00:21,  2.17it/s]

Epoch 53/100, Loss: -0.3152
Epoch 53, L1=2.7480, L2=2.3195, L3=2.4306, L4=2.3195, L5=2.3086, L6=2.3064, Loss=-0.3152


Training Learn W:  54%|█████▍    | 54/100 [00:28<00:21,  2.10it/s]

Epoch 54/100, Loss: -155.8346
Epoch 54, L1=59.8531, L2=7.6824, L3=8.8144, L4=7.6824, L5=1.9908, L6=106.7867, Loss=-155.8346


Training Learn W:  55%|█████▌    | 55/100 [00:28<00:19,  2.27it/s]

Epoch 55/100, Loss: -52.6917
Epoch 55, L1=61.7570, L2=2.5926, L3=9.0724, L4=2.5926, L5=2.3477, L6=2.3548, Loss=-52.6917


Training Learn W:  56%|█████▌    | 56/100 [00:29<00:18,  2.37it/s]

Epoch 56/100, Loss: -0.1290
Epoch 56, L1=2.3762, L2=2.3916, L3=2.3119, L4=2.3916, L5=2.3037, L6=2.3684, Loss=-0.1290


Training Learn W:  57%|█████▋    | 57/100 [00:29<00:17,  2.45it/s]

Epoch 57/100, Loss: -1.2842
Epoch 57, L1=3.8344, L2=2.3153, L3=2.5536, L4=2.3153, L5=2.3028, L6=2.3062, Loss=-1.2842


Training Learn W:  58%|█████▊    | 58/100 [00:29<00:16,  2.55it/s]

Epoch 58/100, Loss: -131.1074
Epoch 58, L1=67.7674, L2=2.3060, L3=2.3108, L4=2.3060, L5=2.3036, L6=67.9544, Loss=-131.1074


Training Learn W:  59%|█████▉    | 59/100 [00:30<00:15,  2.59it/s]

Epoch 59/100, Loss: -1.8965
Epoch 59, L1=2.3042, L2=5.3529, L3=2.7146, L4=5.3529, L5=3.0130, L6=5.3200, Loss=-1.8965


Training Learn W:  60%|██████    | 60/100 [00:30<00:15,  2.62it/s]

Epoch 60/100, Loss: -169.9082
Epoch 60, L1=71.5370, L2=8.6462, L3=8.0746, L4=8.6462, L5=3.8199, L6=110.2656, Loss=-169.9082


Training Learn W:  61%|██████    | 61/100 [00:31<00:19,  1.96it/s]

Epoch 61/100, Loss: -55.7315
Epoch 61, L1=2.3606, L2=9.3533, L3=9.3675, L4=9.3533, L5=1.9836, L6=64.7220, Loss=-55.7315


Training Learn W:  62%|██████▏   | 62/100 [00:32<00:22,  1.66it/s]

Epoch 62/100, Loss: -93.9969
Epoch 62, L1=75.2380, L2=8.4837, L3=23.4965, L4=8.4837, L5=3.6063, L6=45.8618, Loss=-93.9969


Training Learn W:  63%|██████▎   | 63/100 [00:32<00:21,  1.70it/s]

Epoch 63/100, Loss: -50.4337
Epoch 63, L1=2.3105, L2=18.9420, L3=19.0427, L4=18.9420, L5=1.8860, L6=69.0519, Loss=-50.4337


Training Learn W:  64%|██████▍   | 64/100 [00:33<00:21,  1.69it/s]

Epoch 64/100, Loss: -0.6029
Epoch 64, L1=3.0252, L2=2.3186, L3=2.4607, L4=2.3186, L5=2.3213, L6=2.3598, Loss=-0.6029


Training Learn W:  65%|██████▌   | 65/100 [00:33<00:19,  1.79it/s]

Epoch 65/100, Loss: 0.0199
Epoch 65, L1=2.3563, L2=2.5541, L3=2.3595, L4=2.5541, L5=2.3575, L6=2.3407, Loss=0.0199


Training Learn W:  66%|██████▌   | 66/100 [00:34<00:19,  1.77it/s]

Epoch 66/100, Loss: -0.8992
Epoch 66, L1=2.3068, L2=2.5926, L3=2.5687, L4=2.5926, L5=2.7861, L6=3.9472, Loss=-0.8992


Training Learn W:  67%|██████▋   | 67/100 [00:34<00:17,  1.94it/s]

Epoch 67/100, Loss: -1.8471
Epoch 67, L1=2.7540, L2=2.8675, L3=2.8103, L4=2.8675, L5=3.0108, L6=4.9142, Loss=-1.8471


Training Learn W:  68%|██████▊   | 68/100 [00:35<00:16,  1.97it/s]

Epoch 68/100, Loss: -74.9917
Epoch 68, L1=86.7669, L2=2.3715, L3=11.7748, L4=2.3715, L5=2.3088, L6=2.3084, Loss=-74.9917


Training Learn W:  69%|██████▉   | 69/100 [00:35<00:15,  2.01it/s]

Epoch 69/100, Loss: -0.1034
Epoch 69, L1=2.5263, L2=2.3928, L3=2.4303, L4=2.3928, L5=2.3857, L6=2.3931, Loss=-0.1034


Training Learn W:  70%|███████   | 70/100 [00:36<00:15,  1.94it/s]

Epoch 70/100, Loss: -10.1596
Epoch 70, L1=14.2671, L2=2.3180, L3=4.1068, L4=2.3180, L5=2.3047, L6=2.3040, Loss=-10.1596


Training Learn W:  71%|███████   | 71/100 [00:37<00:20,  1.42it/s]

Epoch 71/100, Loss: -0.0191
Epoch 71, L1=2.3530, L2=20.7113, L3=2.3336, L4=20.7113, L5=2.3033, L6=2.3030, Loss=-0.0191


Training Learn W:  72%|███████▏  | 72/100 [00:38<00:20,  1.36it/s]

Epoch 72/100, Loss: -205.3579
Epoch 72, L1=116.6949, L2=2.4656, L3=4.5351, L4=2.4656, L5=2.3029, L6=95.5010, Loss=-205.3579


Training Learn W:  73%|███████▎  | 73/100 [00:39<00:27,  1.02s/it]

Epoch 73/100, Loss: -37.0818
Epoch 73, L1=2.3147, L2=7.6335, L3=7.3508, L4=7.6335, L5=3.2391, L6=45.3570, Loss=-37.0818


Training Learn W:  74%|███████▍  | 74/100 [00:41<00:33,  1.28s/it]

Epoch 74/100, Loss: -188.0386
Epoch 74, L1=96.4486, L2=3.2029, L3=2.5350, L4=3.2029, L5=2.3128, L6=96.4378, Loss=-188.0386


Training Learn W:  75%|███████▌  | 75/100 [00:43<00:31,  1.25s/it]

Epoch 75/100, Loss: -198.8776
Epoch 75, L1=101.7038, L2=13.4937, L3=2.3057, L4=13.4937, L5=2.3033, L6=101.7827, Loss=-198.8776


Training Learn W:  76%|███████▌  | 76/100 [00:43<00:26,  1.11s/it]

Epoch 76/100, Loss: -199.6476
Epoch 76, L1=102.4871, L2=2.6466, L3=2.4269, L4=2.6466, L5=2.3522, L6=101.9396, Loss=-199.6476


Training Learn W:  77%|███████▋  | 77/100 [00:44<00:21,  1.05it/s]

Epoch 77/100, Loss: 0.0425
Epoch 77, L1=2.4100, L2=5.2363, L3=2.4227, L4=5.2363, L5=2.4183, L6=2.3885, Loss=0.0425


Training Learn W:  78%|███████▊  | 78/100 [00:44<00:18,  1.19it/s]

Epoch 78/100, Loss: -0.9717
Epoch 78, L1=3.4848, L2=2.3792, L3=2.5796, L4=2.3792, L5=2.4305, L6=2.4970, Loss=-0.9717


Training Learn W:  79%|███████▉  | 79/100 [00:45<00:16,  1.31it/s]

Epoch 79/100, Loss: -0.5851
Epoch 79, L1=3.0766, L2=2.3047, L3=2.4913, L4=2.3047, L5=2.3039, L6=2.3037, Loss=-0.5851


Training Learn W:  80%|████████  | 80/100 [00:46<00:13,  1.44it/s]

Epoch 80/100, Loss: -220.8540
Epoch 80, L1=112.6032, L2=2.3625, L3=2.3168, L4=2.3625, L5=2.3028, L6=112.8703, Loss=-220.8540


Training Learn W:  81%|████████  | 81/100 [00:46<00:12,  1.56it/s]

Epoch 81/100, Loss: -257.5199
Epoch 81, L1=156.3131, L2=14.9020, L3=7.1388, L4=14.9020, L5=2.3026, L6=110.6483, Loss=-257.5199


Training Learn W:  82%|████████▏ | 82/100 [00:48<00:16,  1.12it/s]

Epoch 82/100, Loss: -76.6714
Epoch 82, L1=104.2439, L2=2.5045, L3=27.9545, L4=2.5045, L5=2.5793, L6=2.9612, Loss=-76.6714


Training Learn W:  83%|████████▎ | 83/100 [00:48<00:12,  1.35it/s]

Epoch 83/100, Loss: -213.3073
Epoch 83, L1=112.3879, L2=2.7900, L3=4.3465, L4=2.7900, L5=3.0550, L6=108.3210, Loss=-213.3073


Training Learn W:  84%|████████▍ | 84/100 [00:48<00:10,  1.60it/s]

Epoch 84/100, Loss: -214.6684
Epoch 84, L1=110.5077, L2=2.3212, L3=2.6083, L4=2.3212, L5=2.3026, L6=109.0716, Loss=-214.6684


Training Learn W:  85%|████████▌ | 85/100 [00:49<00:08,  1.81it/s]

Epoch 85/100, Loss: -319.5594
Epoch 85, L1=212.7630, L2=2.4588, L3=15.2616, L4=2.4588, L5=2.3584, L6=124.4164, Loss=-319.5594


Training Learn W:  86%|████████▌ | 86/100 [00:49<00:06,  2.01it/s]

Epoch 86/100, Loss: -277.5688
Epoch 86, L1=161.0538, L2=2.3066, L3=8.1644, L4=2.3066, L5=2.3032, L6=126.9826, Loss=-277.5688


Training Learn W:  87%|████████▋ | 87/100 [00:49<00:05,  2.27it/s]

Epoch 87/100, Loss: -440.9505
Epoch 87, L1=279.8046, L2=25.3076, L3=8.0518, L4=25.3076, L5=5.4323, L6=174.6299, Loss=-440.9505


Training Learn W:  88%|████████▊ | 88/100 [00:50<00:05,  2.08it/s]

Epoch 88/100, Loss: 0.2066
Epoch 88, L1=2.3488, L2=2.5123, L3=2.5169, L4=2.5123, L5=2.6553, L6=2.6167, Loss=0.2066


Training Learn W:  89%|████████▉ | 89/100 [00:50<00:04,  2.22it/s]

Epoch 89/100, Loss: -2.0812
Epoch 89, L1=4.7287, L2=2.3375, L3=2.6495, L4=2.3375, L5=2.3026, L6=2.3045, Loss=-2.0812


Training Learn W:  90%|█████████ | 90/100 [00:51<00:04,  2.37it/s]

Epoch 90/100, Loss: -0.0010
Epoch 90, L1=2.3088, L2=32.3077, L3=2.3079, L4=32.3077, L5=2.3035, L6=2.3035, Loss=-0.0010


Training Learn W:  91%|█████████ | 91/100 [00:52<00:05,  1.67it/s]

Epoch 91/100, Loss: -120.5423
Epoch 91, L1=2.3182, L2=17.4083, L3=17.4211, L4=17.4083, L5=1.9784, L6=137.6236, Loss=-120.5423


Training Learn W:  92%|█████████▏| 92/100 [00:52<00:05,  1.59it/s]

Epoch 92/100, Loss: -277.7535
Epoch 92, L1=142.0147, L2=18.5764, L3=2.6483, L4=18.5764, L5=2.7422, L6=141.1292, Loss=-277.7535


Training Learn W:  93%|█████████▎| 93/100 [00:53<00:04,  1.51it/s]

Epoch 93/100, Loss: 0.0021
Epoch 93, L1=2.3078, L2=2.3207, L3=2.3087, L4=2.3207, L5=2.3064, L6=2.3052, Loss=0.0021


Training Learn W:  94%|█████████▍| 94/100 [00:54<00:03,  1.57it/s]

Epoch 94/100, Loss: -223.7387
Epoch 94, L1=145.8273, L2=18.6652, L3=66.2944, L4=18.6652, L5=1.9782, L6=146.1840, Loss=-223.7387


Training Learn W:  95%|█████████▌| 95/100 [00:55<00:04,  1.23it/s]

Epoch 95/100, Loss: -97.9097
Epoch 95, L1=2.3546, L2=34.4824, L3=34.5092, L4=34.4824, L5=2.5759, L6=132.6403, Loss=-97.9097


Training Learn W:  96%|█████████▌| 96/100 [00:55<00:02,  1.45it/s]

Epoch 96/100, Loss: -0.0010
Epoch 96, L1=2.3082, L2=2.5430, L3=2.3075, L4=2.5430, L5=2.3033, L6=2.3035, Loss=-0.0010


Training Learn W:  97%|█████████▋| 97/100 [00:56<00:01,  1.61it/s]

Epoch 97/100, Loss: -134.3114
Epoch 97, L1=153.6836, L2=3.9012, L3=19.3685, L4=3.9012, L5=2.3133, L6=2.3096, Loss=-134.3114


Training Learn W:  98%|█████████▊| 98/100 [00:56<00:01,  1.72it/s]

Epoch 98/100, Loss: -0.0217
Epoch 98, L1=2.3379, L2=2.8304, L3=2.4220, L4=2.8304, L5=2.4790, L6=2.5848, Loss=-0.0217


Training Learn W:  99%|█████████▉| 99/100 [00:57<00:00,  1.59it/s]

Epoch 99/100, Loss: -159.0455
Epoch 99, L1=158.2836, L2=6.4189, L3=28.0838, L4=6.4189, L5=7.0080, L6=35.8538, Loss=-159.0455


Training Learn W: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s]

Epoch 100/100, Loss: -3.0374
Epoch 100, L1=2.3169, L2=2.9169, L3=2.9184, L4=2.9169, L5=3.4153, L6=7.0542, Loss=-3.0374





In [41]:
torch.save(model.state_dict(), "./disentangle/learn_W.pth")

In [42]:
class ForgetWModel(torch.nn.Module):
    def __init__(self, input_size=64, out_dim=64):
        super().__init__()
        self.W = torch.nn.Linear(input_size, out_dim, bias=False)

    def forward(self, *features):
        return [self.W(x) for x in features]

In [43]:
model = ForgetWModel().to(device)
optimizer = torch.optim.Adam(list(model.parameters()) + list(v_proj.parameters()) + list(t_proj.parameters()), lr=1e-4)

In [None]:
for epoch in tqdm(range(num_epochs), desc="Training Forget W"):
    for batch in dataloader:
        img_i, img_t, img_ti, txt_i, txt_t, *_ = [x.to(device) for x in batch]

        img_i = F.normalize(v_proj(img_i), dim=-1)
        img_t = F.normalize(v_proj(img_t), dim=-1)
        img_ti = F.normalize(v_proj(img_ti), dim=-1)
        txt_i = F.normalize(t_proj(txt_i), dim=-1)
        txt_t = F.normalize(t_proj(txt_t), dim=-1)

        vi, vt, vti, ti, tt = model(img_i, img_t, img_ti, txt_i, txt_t)

        L1 = symmetric_cross_entropy(vi, ti, device)
        L2 = symmetric_cross_entropy(vt, tt, device)
        L3 = symmetric_cross_entropy(vti, ti, device)
        L4 = symmetric_cross_entropy(vti, tt, device)
        L5 = symmetric_cross_entropy(vti, vt, device)
        L6 = symmetric_cross_entropy(vti, vi, device)

        loss = (L1 + L2 + L6 - L3 - L4 - L5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")
    print(f"Epoch {epoch + 1}, L1={L1.item():.4f}, L2={L2.item():.4f}, L3={L3.item():.4f}, L4={L4.item():.4f}, L5={L5.item():.4f}, L6={L6.item():.4f}, Loss={loss.item():.4f}")

Training Forget W:   1%|          | 1/100 [00:00<00:47,  2.09it/s]

Epoch 1/100, Loss: -0.0055
Epoch 1, L1=2.2966, L2=2.3028, L3=2.3025, L4=2.3028, L5=2.3019, L6=2.3024, Loss=-0.0055


Training Forget W:   2%|▏         | 2/100 [00:00<00:46,  2.11it/s]

Epoch 2/100, Loss: -0.0726
Epoch 2, L1=2.2280, L2=2.3032, L3=2.3034, L4=2.3032, L5=2.3018, L6=2.3046, Loss=-0.0726


Training Forget W:   3%|▎         | 3/100 [00:01<00:41,  2.33it/s]

Epoch 3/100, Loss: -0.2763
Epoch 3, L1=2.0294, L2=2.3038, L3=2.3091, L4=2.3038, L5=2.3022, L6=2.3055, Loss=-0.2763


Training Forget W:   4%|▍         | 4/100 [00:01<00:45,  2.09it/s]

Epoch 4/100, Loss: 0.0896
Epoch 4, L1=2.3946, L2=2.3065, L3=2.3081, L4=2.3065, L5=2.3023, L6=2.3055, Loss=0.0896


Training Forget W:   5%|▌         | 5/100 [00:02<00:51,  1.85it/s]

Epoch 5/100, Loss: 0.0150
Epoch 5, L1=2.3229, L2=2.3105, L3=2.3112, L4=2.3105, L5=2.3023, L6=2.3057, Loss=0.0150


Training Forget W:   6%|▌         | 6/100 [00:02<00:46,  2.04it/s]

Epoch 6/100, Loss: -0.6047
Epoch 6, L1=1.7046, L2=2.3376, L3=2.3138, L4=2.3376, L5=2.3013, L6=2.3057, Loss=-0.6047


Training Forget W:   7%|▋         | 7/100 [00:03<00:45,  2.06it/s]

Epoch 7/100, Loss: -0.4002
Epoch 7, L1=2.0257, L2=2.9382, L3=2.6705, L4=2.9382, L5=2.3140, L6=2.5585, Loss=-0.4002


Training Forget W:   8%|▊         | 8/100 [00:04<00:56,  1.64it/s]

Epoch 8/100, Loss: -1.7639
Epoch 8, L1=1.9887, L2=5.0454, L3=3.3374, L4=5.0454, L5=2.7834, L6=2.3682, Loss=-1.7639


Training Forget W:   9%|▉         | 9/100 [00:04<00:53,  1.70it/s]

Epoch 9/100, Loss: -18.8831
Epoch 9, L1=1.9408, L2=17.6202, L3=21.4226, L4=17.6202, L5=1.8955, L6=2.4942, Loss=-18.8831


Training Forget W:  10%|█         | 10/100 [00:05<01:07,  1.32it/s]

Epoch 10/100, Loss: -23.9059
Epoch 10, L1=2.2201, L2=42.0418, L3=26.4324, L4=42.0418, L5=1.9786, L6=2.2849, Loss=-23.9059


Training Forget W:  11%|█         | 11/100 [00:06<00:59,  1.49it/s]

Epoch 11/100, Loss: -28.5026
Epoch 11, L1=2.3082, L2=31.5752, L3=31.1438, L4=31.5752, L5=1.9775, L6=2.3105, Loss=-28.5026


Training Forget W:  12%|█▏        | 12/100 [00:06<00:54,  1.62it/s]

Epoch 12/100, Loss: -41.2046
Epoch 12, L1=2.5846, L2=75.3809, L3=44.2249, L4=75.3809, L5=2.0194, L6=2.4551, Loss=-41.2046


Training Forget W:  13%|█▎        | 13/100 [00:07<00:47,  1.84it/s]

Epoch 13/100, Loss: -37.6337
Epoch 13, L1=2.5255, L2=41.8740, L3=40.5598, L4=41.8740, L5=1.8386, L6=2.2392, Loss=-37.6337


Training Forget W:  14%|█▍        | 14/100 [00:07<00:42,  2.01it/s]

Epoch 14/100, Loss: -2.5524
Epoch 14, L1=2.4373, L2=2.7298, L3=5.4249, L4=2.7298, L5=2.3028, L6=2.7381, Loss=-2.5524


Training Forget W:  15%|█▌        | 15/100 [00:08<00:39,  2.14it/s]

Epoch 15/100, Loss: -5.8167
Epoch 15, L1=2.3206, L2=7.9218, L3=7.2988, L4=7.9218, L5=3.1932, L6=2.3547, Loss=-5.8167


Training Forget W:  16%|█▌        | 16/100 [00:08<00:36,  2.30it/s]

Epoch 16/100, Loss: -63.6955
Epoch 16, L1=2.5281, L2=65.6296, L3=65.1474, L4=65.6296, L5=3.6048, L6=2.5285, Loss=-63.6955


Training Forget W:  17%|█▋        | 17/100 [00:08<00:36,  2.29it/s]

Epoch 17/100, Loss: -0.4406
Epoch 17, L1=2.3504, L2=76.1774, L3=2.9235, L4=76.1774, L5=2.3032, L6=2.4357, Loss=-0.4406


Training Forget W:  18%|█▊        | 18/100 [00:09<00:39,  2.08it/s]

Epoch 18/100, Loss: 0.4294
Epoch 18, L1=2.7683, L2=2.6860, L3=2.5860, L4=2.6860, L5=2.3729, L6=2.6201, Loss=0.4294


Training Forget W:  19%|█▉        | 19/100 [00:09<00:39,  2.05it/s]

Epoch 19/100, Loss: -5.2911
Epoch 19, L1=2.6296, L2=6.2480, L3=6.6009, L4=6.2480, L5=3.9318, L6=2.6120, Loss=-5.2911


Training Forget W:  20%|██        | 20/100 [00:10<00:34,  2.31it/s]

Epoch 20/100, Loss: -18.3441
Epoch 20, L1=2.5052, L2=129.3230, L3=21.1422, L4=129.3230, L5=2.3042, L6=2.5972, Loss=-18.3441


Training Forget W:  21%|██        | 21/100 [00:10<00:32,  2.41it/s]

Epoch 21/100, Loss: -0.3366
Epoch 21, L1=2.5819, L2=136.5769, L3=3.2166, L4=136.5769, L5=2.3033, L6=2.6015, Loss=-0.3366


Training Forget W:  22%|██▏       | 22/100 [00:11<00:32,  2.39it/s]

Epoch 22/100, Loss: 0.7920
Epoch 22, L1=2.8705, L2=2.9199, L3=2.4538, L4=2.9199, L5=2.3033, L6=2.6786, Loss=0.7920


Training Forget W:  23%|██▎       | 23/100 [00:11<00:31,  2.48it/s]

Epoch 23/100, Loss: -6.1521
Epoch 23, L1=2.5646, L2=11.0085, L3=7.8891, L4=11.0085, L5=3.4847, L6=2.6572, Loss=-6.1521


Training Forget W:  24%|██▍       | 24/100 [00:11<00:30,  2.47it/s]

Epoch 24/100, Loss: -0.5900
Epoch 24, L1=2.6199, L2=2.5631, L3=3.5084, L4=2.5631, L5=2.3140, L6=2.6125, Loss=-0.5900


Training Forget W:  25%|██▌       | 25/100 [00:12<00:29,  2.50it/s]

Epoch 25/100, Loss: 0.4179
Epoch 25, L1=2.9077, L2=3.9264, L3=3.0084, L4=3.9264, L5=2.3043, L6=2.8230, Loss=0.4179


Training Forget W:  26%|██▌       | 26/100 [00:12<00:30,  2.44it/s]

Epoch 26/100, Loss: -15.9591
Epoch 26, L1=2.5878, L2=17.0245, L3=16.2767, L4=17.0245, L5=4.9225, L6=2.6522, Loss=-15.9591


Training Forget W:  27%|██▋       | 27/100 [00:13<00:28,  2.52it/s]

Epoch 27/100, Loss: -153.3671
Epoch 27, L1=2.7262, L2=224.2901, L3=155.9213, L4=224.2901, L5=2.9674, L6=2.7953, Loss=-153.3671


Training Forget W:  28%|██▊       | 28/100 [00:13<00:29,  2.46it/s]

Epoch 28/100, Loss: -233.5276
Epoch 28, L1=3.6226, L2=2.5259, L3=237.6702, L4=2.5259, L5=2.3238, L6=2.8438, Loss=-233.5276


Training Forget W:  29%|██▉       | 29/100 [00:14<00:34,  2.06it/s]

Epoch 29/100, Loss: -1.7086
Epoch 29, L1=3.5773, L2=3.4387, L3=6.1450, L4=3.4387, L5=2.3040, L6=3.1631, Loss=-1.7086


Training Forget W:  30%|███       | 30/100 [00:15<00:43,  1.59it/s]

Epoch 30/100, Loss: -470.4464
Epoch 30, L1=3.8768, L2=268.1640, L3=475.7964, L4=268.1640, L5=1.9804, L6=3.4536, Loss=-470.4464


Training Forget W:  31%|███       | 31/100 [00:15<00:39,  1.75it/s]

Epoch 31/100, Loss: 1.2980
Epoch 31, L1=3.0338, L2=3.0409, L3=2.6185, L4=3.0409, L5=2.3095, L6=3.1923, Loss=1.2980


Training Forget W:  32%|███▏      | 32/100 [00:16<00:36,  1.85it/s]

Epoch 32/100, Loss: -1.8371
Epoch 32, L1=3.3747, L2=5.5755, L3=5.6790, L4=5.5755, L5=3.0315, L6=3.4986, Loss=-1.8371


Training Forget W:  33%|███▎      | 33/100 [00:16<00:34,  1.94it/s]

Epoch 33/100, Loss: 1.3790
Epoch 33, L1=3.1691, L2=2.4713, L3=2.5204, L4=2.4713, L5=2.3593, L6=3.0896, Loss=1.3790


Training Forget W:  34%|███▍      | 34/100 [00:16<00:30,  2.15it/s]

Epoch 34/100, Loss: 1.2403
Epoch 34, L1=3.1198, L2=332.6526, L3=2.7684, L4=332.6526, L5=2.3031, L6=3.1921, Loss=1.2403


Training Forget W:  35%|███▌      | 35/100 [00:17<00:27,  2.36it/s]

Epoch 35/100, Loss: -5.7887
Epoch 35, L1=2.8704, L2=6.6786, L3=7.8811, L4=6.6786, L5=3.3817, L6=2.6037, Loss=-5.7887


Training Forget W:  36%|███▌      | 36/100 [00:17<00:25,  2.50it/s]

Epoch 36/100, Loss: -364.6786
Epoch 36, L1=4.1118, L2=371.6395, L3=370.9023, L4=371.6395, L5=1.9906, L6=4.1026, Loss=-364.6786


Training Forget W:  37%|███▋      | 37/100 [00:17<00:24,  2.54it/s]

Epoch 37/100, Loss: -2.7370
Epoch 37, L1=4.0737, L2=10.6524, L3=6.8124, L4=10.6524, L5=3.3336, L6=3.3354, Loss=-2.7370


Training Forget W:  38%|███▊      | 38/100 [00:18<00:24,  2.55it/s]

Epoch 38/100, Loss: -391.2608
Epoch 38, L1=3.7356, L2=2.3966, L3=395.7758, L4=2.3966, L5=2.3046, L6=3.0839, Loss=-391.2608


Training Forget W:  39%|███▉      | 39/100 [00:18<00:22,  2.72it/s]

Epoch 39/100, Loss: -140.3386
Epoch 39, L1=2.9205, L2=133.9701, L3=133.1775, L4=133.9701, L5=13.3478, L6=3.2662, Loss=-140.3386


Training Forget W:  40%|████      | 40/100 [00:18<00:20,  2.89it/s]

Epoch 40/100, Loss: -17.9357
Epoch 40, L1=3.0174, L2=14.7223, L3=19.1575, L4=14.7223, L5=4.9457, L6=3.1502, Loss=-17.9357


Training Forget W:  41%|████      | 41/100 [00:19<00:20,  2.93it/s]

Epoch 41/100, Loss: 0.5867
Epoch 41, L1=4.0562, L2=2.3922, L3=4.7249, L4=2.3922, L5=2.3036, L6=3.5590, Loss=0.5867


Training Forget W:  42%|████▏     | 42/100 [00:19<00:20,  2.79it/s]

Epoch 42/100, Loss: -0.2617
Epoch 42, L1=4.0199, L2=4.4125, L3=4.8270, L4=4.4125, L5=2.7673, L6=3.3127, Loss=-0.2617


Training Forget W:  43%|████▎     | 43/100 [00:19<00:20,  2.78it/s]

Epoch 43/100, Loss: 1.9377
Epoch 43, L1=3.7556, L2=3.0383, L3=2.6667, L4=3.0383, L5=2.4243, L6=3.2731, Loss=1.9377


Training Forget W:  44%|████▍     | 44/100 [00:20<00:21,  2.66it/s]

Epoch 44/100, Loss: -518.6434
Epoch 44, L1=6.2670, L2=198.4072, L3=526.4255, L4=198.4072, L5=2.6112, L6=4.1264, Loss=-518.6434


Training Forget W:  45%|████▌     | 45/100 [00:20<00:21,  2.57it/s]

Epoch 45/100, Loss: -114.1477
Epoch 45, L1=3.2845, L2=104.5049, L3=104.3129, L4=104.5049, L5=16.4499, L6=3.3307, Loss=-114.1477


Training Forget W:  46%|████▌     | 46/100 [00:21<00:19,  2.76it/s]

Epoch 46/100, Loss: -864.5118
Epoch 46, L1=4.3063, L2=567.7132, L3=871.8423, L4=567.7132, L5=1.9810, L6=5.0052, Loss=-864.5118


Training Forget W:  47%|████▋     | 47/100 [00:21<00:19,  2.74it/s]

Epoch 47/100, Loss: -580.0427
Epoch 47, L1=5.4494, L2=3.2285, L3=587.5236, L4=3.2285, L5=2.4596, L6=4.4911, Loss=-580.0427


Training Forget W:  48%|████▊     | 48/100 [00:22<00:23,  2.20it/s]

Epoch 48/100, Loss: -537.2321
Epoch 48, L1=3.0937, L2=608.6031, L3=541.7462, L4=608.6031, L5=2.3038, L6=3.7243, Loss=-537.2321


Training Forget W:  49%|████▉     | 49/100 [00:22<00:21,  2.33it/s]

Epoch 49/100, Loss: 2.7526
Epoch 49, L1=4.5390, L2=629.4252, L3=3.6759, L4=629.4252, L5=2.3068, L6=4.1962, Loss=2.7526


Training Forget W:  50%|█████     | 50/100 [00:22<00:21,  2.29it/s]

Epoch 50/100, Loss: -644.8588
Epoch 50, L1=4.2828, L2=652.8253, L3=651.7540, L4=652.8253, L5=1.9813, L6=4.5936, Loss=-644.8588


Training Forget W:  51%|█████     | 51/100 [00:23<00:19,  2.49it/s]

Epoch 51/100, Loss: -616.6545
Epoch 51, L1=4.5331, L2=625.3165, L3=624.6680, L4=625.3165, L5=1.9794, L6=5.4597, Loss=-616.6545


Training Forget W:  52%|█████▏    | 52/100 [00:23<00:20,  2.34it/s]

Epoch 52/100, Loss: 2.0215
Epoch 52, L1=5.1423, L2=2.8318, L3=5.8164, L4=2.8318, L5=2.3093, L6=5.0049, Loss=2.0215


Training Forget W:  53%|█████▎    | 53/100 [00:24<00:22,  2.09it/s]

Epoch 53/100, Loss: -56.9278
Epoch 53, L1=3.7273, L2=57.4496, L3=52.9596, L4=57.4496, L5=14.1346, L6=6.4391, Loss=-56.9278


Training Forget W:  54%|█████▍    | 54/100 [00:24<00:20,  2.29it/s]

Epoch 54/100, Loss: -737.8113
Epoch 54, L1=5.4842, L2=2.4435, L3=745.5776, L4=2.4435, L5=2.3032, L6=4.5854, Loss=-737.8113


Training Forget W:  55%|█████▌    | 55/100 [00:25<00:19,  2.30it/s]

Epoch 55/100, Loss: -762.2717
Epoch 55, L1=4.5054, L2=1354.2932, L3=769.0945, L4=1354.2932, L5=3.4193, L6=5.7366, Loss=-762.2717


Training Forget W:  56%|█████▌    | 56/100 [00:25<00:22,  1.96it/s]

Epoch 56/100, Loss: -236.0375
Epoch 56, L1=5.4837, L2=766.7562, L3=209.6572, L4=766.7562, L5=36.4619, L6=4.5979, Loss=-236.0375


Training Forget W:  57%|█████▋    | 57/100 [00:26<00:23,  1.82it/s]

Epoch 57/100, Loss: 4.5407
Epoch 57, L1=4.9712, L2=25.8539, L3=3.0477, L4=25.8539, L5=2.3036, L6=4.9208, Loss=4.5407


Training Forget W:  58%|█████▊    | 58/100 [00:26<00:20,  2.01it/s]

Epoch 58/100, Loss: 3.5022
Epoch 58, L1=4.6000, L2=6.3769, L3=3.5951, L4=6.3769, L5=2.5385, L6=5.0358, Loss=3.5022


Training Forget W:  59%|█████▉    | 59/100 [00:27<00:18,  2.16it/s]

Epoch 59/100, Loss: -859.5583
Epoch 59, L1=4.5898, L2=15.5100, L3=865.2826, L4=15.5100, L5=2.3187, L6=3.4532, Loss=-859.5583


Training Forget W:  60%|██████    | 60/100 [00:27<00:17,  2.22it/s]

Epoch 60/100, Loss: -880.0397
Epoch 60, L1=9.7001, L2=3.6535, L3=892.0602, L4=3.6535, L5=2.3111, L6=4.6315, Loss=-880.0397


Training Forget W:  61%|██████    | 61/100 [00:28<00:17,  2.21it/s]

Epoch 61/100, Loss: -805.1633
Epoch 61, L1=5.9516, L2=814.7412, L3=816.6360, L4=814.7412, L5=2.3125, L6=7.8337, Loss=-805.1633


Training Forget W:  62%|██████▏   | 62/100 [00:28<00:17,  2.13it/s]

Epoch 62/100, Loss: 0.9069
Epoch 62, L1=6.1891, L2=837.8569, L3=7.0009, L4=837.8569, L5=3.9054, L6=5.6241, Loss=0.9069


Training Forget W:  63%|██████▎   | 63/100 [00:29<00:17,  2.13it/s]

Epoch 63/100, Loss: 5.7356
Epoch 63, L1=5.3512, L2=4.3142, L3=4.4175, L4=4.3142, L5=2.6466, L6=7.4486, Loss=5.7356


Training Forget W:  64%|██████▍   | 64/100 [00:29<00:15,  2.28it/s]

Epoch 64/100, Loss: -984.5528
Epoch 64, L1=6.2117, L2=994.6866, L3=995.7063, L4=994.6866, L5=2.1376, L6=7.0794, Loss=-984.5528


Training Forget W:  65%|██████▌   | 65/100 [00:29<00:13,  2.52it/s]

Epoch 65/100, Loss: -964.8658
Epoch 65, L1=6.0488, L2=974.3838, L3=975.1928, L4=974.3838, L5=1.9840, L6=6.2622, Loss=-964.8658


Training Forget W:  66%|██████▌   | 66/100 [00:30<00:12,  2.71it/s]

Epoch 66/100, Loss: -1032.7345
Epoch 66, L1=6.1485, L2=1043.6875, L3=1045.4819, L4=1043.6875, L5=3.1632, L6=9.7621, Loss=-1032.7345


Training Forget W:  67%|██████▋   | 67/100 [00:30<00:12,  2.72it/s]

Epoch 67/100, Loss: -7.8359
Epoch 67, L1=5.3629, L2=3.1180, L3=16.1427, L4=3.1180, L5=2.3108, L6=5.2547, Loss=-7.8359


Training Forget W:  68%|██████▊   | 68/100 [00:30<00:12,  2.54it/s]

Epoch 68/100, Loss: -1087.9584
Epoch 68, L1=7.3528, L2=1118.8708, L3=1101.7086, L4=1118.8708, L5=3.0561, L6=9.4535, Loss=-1087.9584


Training Forget W:  69%|██████▉   | 69/100 [00:31<00:12,  2.44it/s]

Epoch 69/100, Loss: -995.9267
Epoch 69, L1=6.3610, L2=3.2629, L3=1005.6277, L4=3.2629, L5=2.3150, L6=5.6551, Loss=-995.9267


Training Forget W:  70%|███████   | 70/100 [00:31<00:12,  2.49it/s]

Epoch 70/100, Loss: 3.8447
Epoch 70, L1=4.5317, L2=2.6497, L3=2.6631, L4=2.6497, L5=2.3090, L6=4.2851, Loss=3.8447


Training Forget W:  71%|███████   | 71/100 [00:31<00:10,  2.79it/s]

Epoch 71/100, Loss: 5.0167
Epoch 71, L1=6.4711, L2=6.2263, L3=5.2901, L4=6.2263, L5=2.5076, L6=6.3433, Loss=5.0167


Training Forget W:  72%|███████▏  | 72/100 [00:32<00:09,  2.95it/s]

Epoch 72/100, Loss: -1208.7141
Epoch 72, L1=9.9174, L2=10.5403, L3=1220.2467, L4=10.5403, L5=4.6236, L6=6.2388, Loss=-1208.7141


Training Forget W:  73%|███████▎  | 73/100 [00:32<00:09,  2.93it/s]

Epoch 73/100, Loss: -204.6584
Epoch 73, L1=6.1686, L2=2.7337, L3=214.1749, L4=2.7337, L5=2.3090, L6=5.6569, Loss=-204.6584


Training Forget W:  74%|███████▍  | 74/100 [00:32<00:08,  2.92it/s]

Epoch 74/100, Loss: 1.9210
Epoch 74, L1=6.8403, L2=2.6368, L3=9.7022, L4=2.6368, L5=2.3060, L6=7.0890, Loss=1.9210


Training Forget W:  75%|███████▌  | 75/100 [00:33<00:09,  2.61it/s]

Epoch 75/100, Loss: -1298.4799
Epoch 75, L1=3.7472, L2=2313.1357, L3=1305.2303, L4=2313.1357, L5=1.9807, L6=4.9840, Loss=-1298.4799


Training Forget W:  76%|███████▌  | 76/100 [00:34<00:12,  1.90it/s]

Epoch 76/100, Loss: -1326.8490
Epoch 76, L1=4.1330, L2=1333.9468, L3=1334.6804, L4=1333.9468, L5=4.6073, L6=8.3056, Loss=-1326.8490


Training Forget W:  77%|███████▋  | 77/100 [00:34<00:11,  1.94it/s]

Epoch 77/100, Loss: -1213.5626
Epoch 77, L1=5.6362, L2=1213.8258, L3=1225.7129, L4=1213.8258, L5=2.0007, L6=8.5148, Loss=-1213.5626


Training Forget W:  78%|███████▊  | 78/100 [00:35<00:11,  1.90it/s]

Epoch 78/100, Loss: 7.0437
Epoch 78, L1=7.1042, L2=1394.4084, L3=3.6516, L4=1394.4084, L5=2.5185, L6=6.1095, Loss=7.0437


Training Forget W:  79%|███████▉  | 79/100 [00:35<00:10,  2.08it/s]

Epoch 79/100, Loss: 6.0339
Epoch 79, L1=4.2385, L2=1261.2889, L3=3.4657, L4=1261.2889, L5=2.3103, L6=7.5715, Loss=6.0339


Training Forget W:  80%|████████  | 80/100 [00:36<00:08,  2.23it/s]

Epoch 80/100, Loss: 2.1244
Epoch 80, L1=7.4784, L2=5.8743, L3=8.2679, L4=5.8743, L5=2.6846, L6=5.5985, Loss=2.1244


Training Forget W:  81%|████████  | 81/100 [00:36<00:07,  2.47it/s]

Epoch 81/100, Loss: 7.1418
Epoch 81, L1=9.0595, L2=4.4028, L3=4.6763, L4=4.4028, L5=2.7031, L6=5.4616, Loss=7.1418


Training Forget W:  82%|████████▏ | 82/100 [00:36<00:07,  2.52it/s]

Epoch 82/100, Loss: -144.3012
Epoch 82, L1=6.8423, L2=124.6730, L3=133.3122, L4=124.6730, L5=25.3318, L6=7.5006, Loss=-144.3012


Training Forget W:  83%|████████▎ | 83/100 [00:37<00:06,  2.45it/s]

Epoch 83/100, Loss: -25.0884
Epoch 83, L1=9.1311, L2=867.8654, L3=32.9030, L4=867.8654, L5=7.9056, L6=6.5892, Loss=-25.0884


Training Forget W:  84%|████████▍ | 84/100 [00:37<00:06,  2.48it/s]

Epoch 84/100, Loss: -1453.3308
Epoch 84, L1=6.9682, L2=1435.4093, L3=1432.5177, L4=1435.4093, L5=39.8488, L6=12.0674, Loss=-1453.3308


Training Forget W:  85%|████████▌ | 85/100 [00:37<00:05,  2.60it/s]

Epoch 85/100, Loss: -1593.4114
Epoch 85, L1=7.5576, L2=1611.6871, L3=1612.1189, L4=1611.6871, L5=1.9811, L6=13.1309, Loss=-1593.4114


Training Forget W:  86%|████████▌ | 86/100 [00:38<00:05,  2.71it/s]

Epoch 86/100, Loss: -1621.2633
Epoch 86, L1=10.2225, L2=1668.6815, L3=1640.6445, L4=1668.6815, L5=2.9257, L6=12.0843, Loss=-1621.2633


Training Forget W:  87%|████████▋ | 87/100 [00:38<00:04,  2.90it/s]

Epoch 87/100, Loss: 10.9913
Epoch 87, L1=9.4968, L2=2.3665, L3=5.8954, L4=2.3665, L5=2.3055, L6=9.6954, Loss=10.9913


Training Forget W:  88%|████████▊ | 88/100 [00:38<00:04,  2.83it/s]

Epoch 88/100, Loss: 10.3790
Epoch 88, L1=9.1931, L2=7.0567, L3=4.4023, L4=7.0567, L5=2.8519, L6=8.4401, Loss=10.3790


Training Forget W:  89%|████████▉ | 89/100 [00:39<00:04,  2.71it/s]

Epoch 89/100, Loss: -2394.3489
Epoch 89, L1=10.8319, L2=1420.3405, L3=2358.0090, L4=1420.3405, L5=55.5466, L6=8.3749, Loss=-2394.3489


Training Forget W:  90%|█████████ | 90/100 [00:39<00:03,  2.73it/s]

Epoch 90/100, Loss: 4.3633
Epoch 90, L1=5.7347, L2=1774.2498, L3=4.8975, L4=1774.2498, L5=2.3110, L6=5.8371, Loss=4.3633


Training Forget W:  91%|█████████ | 91/100 [00:40<00:03,  2.77it/s]

Epoch 91/100, Loss: 6.8523
Epoch 91, L1=7.7229, L2=11.6741, L3=5.9365, L4=11.6741, L5=2.4113, L6=7.4772, Loss=6.8523


Training Forget W:  92%|█████████▏| 92/100 [00:40<00:02,  2.72it/s]

Epoch 92/100, Loss: 8.8712
Epoch 92, L1=6.9478, L2=10.9365, L3=2.7845, L4=10.9365, L5=2.3148, L6=7.0227, Loss=8.8712


Training Forget W:  93%|█████████▎| 93/100 [00:40<00:02,  2.88it/s]

Epoch 93/100, Loss: -1123.2388
Epoch 93, L1=7.9820, L2=2656.9038, L3=1140.4376, L4=2656.9038, L5=1.9913, L6=11.2082, Loss=-1123.2388


Training Forget W:  94%|█████████▍| 94/100 [00:41<00:02,  2.91it/s]

Epoch 94/100, Loss: 7.5920
Epoch 94, L1=9.5919, L2=4.2402, L3=5.2912, L4=4.2402, L5=2.3092, L6=5.6006, Loss=7.5920


Training Forget W:  95%|█████████▌| 95/100 [00:41<00:01,  3.00it/s]

Epoch 95/100, Loss: -8.6985
Epoch 95, L1=9.8077, L2=1786.2628, L3=23.0233, L4=1786.2628, L5=3.3983, L6=7.9154, Loss=-8.6985


Training Forget W:  96%|█████████▌| 96/100 [00:41<00:01,  2.66it/s]

Epoch 96/100, Loss: -1963.8224
Epoch 96, L1=11.8657, L2=1758.8044, L3=1979.8074, L4=1758.8044, L5=2.5741, L6=6.6933, Loss=-1963.8224


Training Forget W:  97%|█████████▋| 97/100 [00:42<00:01,  2.22it/s]

Epoch 97/100, Loss: -28.9851
Epoch 97, L1=8.4018, L2=2028.9266, L3=37.5554, L4=2028.9266, L5=7.8000, L6=7.9685, Loss=-28.9851


Training Forget W:  98%|█████████▊| 98/100 [00:42<00:00,  2.23it/s]

Epoch 98/100, Loss: -2015.8403
Epoch 98, L1=6.8997, L2=16.4761, L3=2026.1582, L4=16.4761, L5=7.0775, L6=10.4957, Loss=-2015.8403


Training Forget W:  99%|█████████▉| 99/100 [00:43<00:00,  2.39it/s]

Epoch 99/100, Loss: -1843.6803
Epoch 99, L1=7.2027, L2=1850.6130, L3=1852.4027, L4=1850.6130, L5=7.0041, L6=8.5239, Loss=-1843.6803


Training Forget W: 100%|██████████| 100/100 [00:43<00:00,  2.30it/s]

Epoch 100/100, Loss: -2017.2921
Epoch 100, L1=12.4930, L2=77.7202, L3=2035.1296, L4=77.7202, L5=3.1650, L6=8.5095, Loss=-2017.2921





In [45]:
torch.save(model.state_dict(), "./disentangle/forget_W.pth")