In [1]:
import sys
sys.path.append("..")

In [3]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from hrdae.models.networks import create_network, RDAE2dOption
from hrdae.models.networks.motion_encoder import MotionRNNEncoder1dOption
from hrdae.models.networks.rnn import TCN1dOption
from hrdae.dataloaders.datasets import create_dataset, MovingMNISTDatasetOption
from hrdae.dataloaders.transforms import create_transform, MinMaxNormalizationOption


In [13]:
# Assemble the configuration from the innermost option outwards so that each
# component can be read on its own before it is plugged into its parent.

# Option object handed to the motion encoder as its `rnn` component.
tcn_option = TCN1dOption(
    num_layers=3,
    image_size=8,
    kernel_size=4,
    dropout=0.1,
)

# NOTE(review): in_channels=5 presumably corresponds to the five slice_index
# entries used when the dataset is created further down - confirm.
motion_encoder_option = MotionRNNEncoder1dOption(
    in_channels=5,
    hidden_channels=64,
    conv_params=[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
    deconv_params=[{"kernel_size": [3], "stride": [1, 2], "padding": [1]}] * 3,
    rnn=tcn_option,
)

rdae_option = RDAE2dOption(
    activation="sigmoid",
    aggregator="addition",
    cycle=False,
    in_channels=1,
    hidden_channels=64,
    latent_dim=8,
    conv_params=[{"kernel_size": [3], "stride": [2], "padding": [1]}] * 3,
    motion_encoder=motion_encoder_option,
)

# NOTE(review): the leading positional argument (1) is presumably the number of
# output channels - confirm against create_network's signature.
net = create_network(1, opt=rdae_option)

In [14]:
# Restore the trained weights for inference.
# - map_location="cpu": lets a checkpoint that was written from a GPU run load
#   on a CPU-only machine (no visible cell moves the network to another device).
# - weights_only=True: restricts unpickling to tensors and plain containers, so
#   loading the file cannot execute arbitrary pickled code.
state_dict = torch.load(
    "../results/BasicDataLoaderOption/PVRModelOption/rdae2d/2024-06-27_21-27-00/weights/best_model.pth",
    map_location="cpu",
    weights_only=True,
)
# Last expression of the cell: the key-matching report is displayed as output.
net.load_state_dict(state_dict)

<All keys matched successfully>

In [15]:
# Put the module into evaluation mode before running inference. eval() returns
# the module itself, so the cell output is the network's repr.
net.eval()

RDAE2d(
  (content_encoder): Encoder2d(
    (cnn): ConvModule2d(
      (layers): ModuleList(
        (0): Sequential(
          (0): ConvBlock2d(
            (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          )
          (1): IdenticalConvBlock2d(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (1-2): 2 x Sequential(
          (0): ConvBlock2d(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          )
          (1): IdenticalConvBlock2d(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (3): ConvBlock2d(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (bottleneck): PixelWiseConv2d(
      (conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (motion_encoder): MotionRNNEncoder1d(
    (cnn): ConvModule1d(
   

In [16]:
# Describe the data first, then hand the pieces to the dataset factory.
dataset_option = MovingMNISTDatasetOption(
    sequential=True,
    root="../data",
    slice_index=[16, 24, 32, 40, 48],
    content_phase="0",
    motion_phase="0",
    motion_aggregation="none",
)

# Preprocessing pipeline with a single step: min-max normalisation.
preprocessing = transforms.Compose(
    [create_transform(MinMaxNormalizationOption())]
)

# is_train=False requests the non-training variant of the dataset.
dataset = create_dataset(
    opt=dataset_option,
    transform=preprocessing,
    is_train=False,
)

In [17]:
# Walk the dataset in its stored order, ten samples per batch.
loader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=10,
)

In [52]:
# Compare the latents the network produces for the first batch only (note the
# `break` at the end of the loop body).
#
# The whole analysis runs under torch.no_grad(): this is pure inference, so no
# autograd graph has to be recorded for the ten forward passes and the printed
# tensors no longer carry a grad_fn.
with torch.no_grad():
    for data in loader:
        xm = data["xm"]
        xp = data["xp"]
        ys, latents = [], []
        # One forward pass per index i along dim 1 of xp / xm; keep the output
        # and the first element of the second return value.
        # NOTE(review): dim 1 is presumably the frame axis and latent[0] the
        # content latent - confirm against RDAE2d.forward.
        for i in range(10):
            y, latent = net(xm, xp[:, i], xm[:, i])
            ys.append(y)
            latents.append(latent[0])
        for i in range(10):
            for j in range(10):
                if i != j:
                    # Off-diagonal: mean squared difference between entry 0 of
                    # latents[i] and entry 0 of latents[j].
                    print(i, j, ((latents[i][0]-latents[j][0])**2).mean())
                else:
                    # Diagonal: average over k of the mean squared difference
                    # between entry 0 and entry k of latents[i] (k = 0 adds 0).
                    # NOTE(review): k presumably indexes the samples of the
                    # batch (batch_size=10) - confirm.
                    mse = 0.
                    for k in range(10):
                        mse += ((latents[i][0] - latents[i][k])**2).mean()
                    print(i, i, mse / 10)
        # Earlier per-sample experiment, kept disabled for reference:
        # for i in range(len(y)):
        #     for j in range(len(y)):
        #         if i == j:
        #             l1 = 0.
        #             for _xp_0 in xp[i]:
        #                 l2 = net.content_encoder(_xp_0.unsqueeze(0)).squeeze()
        #                 l1 += float(((latent[0][i]-l2)**2).mean())
        #                 print(float(((latent[0][i]-l2)**2).mean()))
        #             print(f"i={i} / j={j}", l1 / len(xp[i]))
        #             continue
        #         print(f"i={i} / j={j}", float(((latent[0][i]-latent[0][j])**2).mean()))
        #     print("mse:", float(((y[i]-xp[i])**2).mean()))

        break

0 0 tensor(1.4039e-10, grad_fn=<DivBackward0>)
0 1 tensor(2.2091e-10, grad_fn=<MeanBackward0>)
0 2 tensor(1.8385e-10, grad_fn=<MeanBackward0>)
0 3 tensor(9.5200e-11, grad_fn=<MeanBackward0>)
0 4 tensor(3.6294e-10, grad_fn=<MeanBackward0>)
0 5 tensor(3.1136e-10, grad_fn=<MeanBackward0>)
0 6 tensor(1.1639e-10, grad_fn=<MeanBackward0>)
0 7 tensor(7.3829e-11, grad_fn=<MeanBackward0>)
0 8 tensor(4.8012e-11, grad_fn=<MeanBackward0>)
0 9 tensor(5.2832e-11, grad_fn=<MeanBackward0>)
1 0 tensor(2.2091e-10, grad_fn=<MeanBackward0>)
1 1 tensor(3.0423e-10, grad_fn=<DivBackward0>)
1 2 tensor(7.5801e-11, grad_fn=<MeanBackward0>)
1 3 tensor(3.8858e-10, grad_fn=<MeanBackward0>)
1 4 tensor(6.2034e-10, grad_fn=<MeanBackward0>)
1 5 tensor(2.4635e-10, grad_fn=<MeanBackward0>)
1 6 tensor(1.1119e-10, grad_fn=<MeanBackward0>)
1 7 tensor(2.3004e-10, grad_fn=<MeanBackward0>)
1 8 tensor(2.6573e-10, grad_fn=<MeanBackward0>)
1 9 tensor(3.0289e-10, grad_fn=<MeanBackward0>)
2 0 tensor(1.8385e-10, grad_fn=<MeanBackwa

In [21]:
# Mean squared difference between xp_0 and the slice xp[:, 0].
# NOTE(review): xp_0 is not defined in any cell visible in this notebook, and
# this cell's execution count (21) is lower than that of the loop cell (52)
# that assigns xp - confirm the notebook still runs top to bottom on a fresh
# kernel.
((xp_0 - xp[:, 0])**2).mean()

tensor(0.)

In [53]:
import torch
import torch.nn.functional as F

In [80]:
def contrastive_loss(features, tau=0.1):
    """
    Margin-based pairwise contrastive loss for class-grouped features.

    Every ordered pair of distinct samples contributes ``0.5 * d**2`` when both
    samples belong to the same class and ``0.5 * max(tau - d, 0)**2`` otherwise,
    where ``d`` is the Euclidean distance between the two feature vectors. The
    result is the sum of these terms divided by the number of ordered pairs,
    ``t*n * (t*n - 1)``.

    :param features: Tensor of shape (t=num_classes, n=samples_per_class, c=feature_dim).
        It does not have to be contiguous.
    :param tau: Margin of the hinge term: pairs from different classes that are
        at least ``tau`` apart contribute nothing. (Despite its name it is not a
        softmax temperature.)
    :return: Scalar tensor with the loss; zero when there are no pairs at all
        (fewer than two samples).
    """
    t, n, c = features.size()
    total = t * n
    # reshape instead of view so that non-contiguous inputs (e.g. transposed
    # tensors) are accepted as well.
    flat = features.reshape(total, c)
    # cdist with p=2 yields plain Euclidean distances; they are squared below.
    distances = torch.cdist(flat, flat, p=2)

    # Row r of `flat` belongs to class r // n. Comparing the class ids pairwise
    # gives the same-class mask (the diagonal, i.e. each sample paired with
    # itself, counts as same-class), built directly on the input's device.
    class_ids = torch.arange(total, device=flat.device) // n
    same_class = class_ids.unsqueeze(0) == class_ids.unsqueeze(1)

    positive_loss = 0.5 * torch.pow(distances, 2)
    negative_loss = 0.5 * torch.pow(torch.clamp(tau - distances, min=0.0), 2)
    pair_loss = torch.where(same_class, positive_loss, negative_loss)

    # Self-pairs sit on the diagonal with (numerically) zero distance, so they
    # add nothing to the sum and are left out of the pair count. With fewer than
    # two samples there are no pairs; divide by 1 to return 0 rather than NaN.
    num_pairs = max(total * (total - 1), 1)
    return torch.sum(pair_loss) / num_pairs


In [81]:
# Example: evaluate the contrastive loss on randomly generated features.
num_classes = 4
samples_per_class = 3
feature_dim = 10

# Draw the features from a dedicated, seeded generator so that re-running the
# cell reproduces the same loss value without touching the global RNG state.
example_rng = torch.Generator().manual_seed(0)
features = torch.randn(
    num_classes, samples_per_class, feature_dim, generator=example_rng
)

# Compute loss
loss = contrastive_loss(features)
print("Contrastive Loss:", loss.item())

Contrastive Loss: 1.3802295923233032


In [82]:
# One random scalar per class, tiled over 3 samples and 10 feature dimensions,
# so every sample of a class carries identical features.
per_class_value = torch.randn(4, 1, 1)
contrastive_loss(per_class_value.repeat(1, 3, 10))

tensor(0.0002)