In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
import glob
import torch.nn.functional as F
import provider   
import argparse
import os
from tqdm import tqdm

from tensorboardX import SummaryWriter
from code.utils.cluster.IID_losses import IID_loss

In [2]:
parser = argparse.ArgumentParser(description='byol-pointnet_baseline')
# parser.add_argument('--image_folder', type=str, required = True, help='path to your folder of images for self-supervised learning')
parser.add_argument('--train_dir', default="/data2/ABC2/data_raw_clustering/pc_correct_num", required=False, help='Training data root.')
parser.add_argument('--csv_dir', default="/data2/ABC2/GT-final.csv", required=False, help='GT label data root.')
parser.add_argument('--taskType', type=str, default='Clustering', help='Type of task.')
parser.add_argument('--batchSize', type=int, default=10, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=12)
parser.add_argument("--lamb", type=float, default=1.0)
args = parser.parse_args([])

In [3]:
class ABC2Dataset_pc(Dataset):
    """ABC2Dataset"""

    def __init__(self, csv_file, root_dir, task_type, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with ground truth.
            root_dir (string): Directory with all the model files.
            task_type (string): Task type.
        """

        # self.ground_truth_file = pd.read_csv(csv_file, header=0)
        self.model_id = glob.glob(os.path.join(root_dir, '*'))
        self.root_dir = root_dir
        self.task_type = task_type
        # print("model_id:", len(self.model_id))

    def __len__(self):
        return len(self.model_id)

    def __getitem__(self, idx):

        "Clustering"
        if self.task_type == "Clustering":
            model_ID = self.model_id[idx].split('/')[-1]
  
            obj_path = glob.glob(os.path.join(self.model_id[idx], '*.pt'))
#             print(obj_path)
            point_cloud_normalized = torch.load(obj_path[0])
            # print("wen_debug dataloader", point_cloud_normalized.shape)
            pc1 = point_cloud_normalized
#             point_cloud_normalized = point_cloud_normalized.unsqueeze(0)
#             # print("wen_debug dataloader", point_cloud_normalized.shape)
#             pc1 = provider.random_point_dropout(point_cloud_normalized)
#             pc1 = provider.random_scale_point_cloud(pc1)
#             pc1 = provider.shift_point_cloud(pc1)

            point_cloud_normalized = point_cloud_normalized.unsqueeze(0)
            pc2 = provider.random_point_dropout(point_cloud_normalized)
            pc2 = provider.random_scale_point_cloud(pc2)
            pc2 = provider.shift_point_cloud(pc2)

            return pc1.transpose(0,1), pc2.squeeze(0).transpose(0,1)
        
class ABC2Dataset_pc_test(Dataset):
    """ABC2Dataset"""

    def __init__(self, csv_file, root_dir, task_type, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with ground truth.
            root_dir (string): Directory with all the model files.
            task_type (string): Task type.
        """

        # self.ground_truth_file = pd.read_csv(csv_file, header=0)
        self.model_id = glob.glob(os.path.join(root_dir, '*'))
        self.root_dir = root_dir
        self.task_type = task_type
        # print("model_id:", len(self.model_id))

    def __len__(self):
        return len(self.model_id)

    def __getitem__(self, idx):

        "Clustering"
        if self.task_type == "Clustering":
            model_ID = self.model_id[idx].split('/')[-1]
  
            obj_path = glob.glob(os.path.join(self.model_id[idx], '*.pt'))
            
            point_cloud_normalized = torch.load(obj_path[0])
            # print("wen_debug dataloader", point_cloud_normalized.shape)

            return point_cloud_normalized.transpose(0,1), model_ID

In [4]:
class PointNetEncoder(nn.Module):
    def __init__(self, global_feat=True, channel=3):
        super(PointNetEncoder, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        B, D, N = x.size()
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        return x
        

class pointnet_cls(nn.Module):
    def __init__(self, k=2000, normal_channel=False):
        super(pointnet_cls, self).__init__()
        if normal_channel:
            channel = 6
        else:
            channel = 3
        self.feat = PointNetEncoder(global_feat=True, channel=channel)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        x = F.softmax(x, dim=1)
        return x

In [5]:
train_data = ABC2Dataset_pc(args.csv_dir, args.train_dir, args.taskType)
train_loader = DataLoader(train_data, batch_size=args.batchSize, num_workers=args.workers, shuffle=True)

In [6]:
len(train_data)

22968

In [7]:
model = pointnet_cls(k=256)

opt = torch.optim.Adam(model.parameters(), lr=3e-4)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model = model.train()
writer = SummaryWriter('iic_pointnet_baseline')


In [8]:
for epoch in range(100):
    Loss = 0
    for batch_id, data in tqdm(enumerate(train_loader, 0), total=len(train_loader), smoothing=0):
        pc1, pc2 = data
        pc1, pc2 = pc1.cuda(), pc2.cuda()
        
        x_outs = model(pc1)
        x_tf_outs = model(pc2)
#         print(pc1.shape, pc2.shape, x_outs.shape, x_tf_outs.shape)
        loss, loss_no_lamb = IID_loss(x_outs, x_tf_outs, lamb=args.lamb)
#         print(loss, loss_no_lamb)
        Loss += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
    
    writer.add_scalars('train_loss', {'train_loss': Loss,}, epoch)
    print("epoch, Training loss: ", epoch, Loss)
# save your improved network
torch.save(model.state_dict(), './model_iic_k256.pt')

writer.export_scalars_to_json("./iic_pointnet_baseline.json")
writer.close()

100%|██████████| 2297/2297 [01:04<00:00, 35.57it/s]

epoch, Training loss:  0 -2563.2268825969077



100%|██████████| 2297/2297 [01:07<00:00, 34.17it/s]

epoch, Training loss:  1 -4050.952571630478



100%|██████████| 2297/2297 [01:08<00:00, 33.74it/s]

epoch, Training loss:  2 -4500.323747038841



100%|██████████| 2297/2297 [01:08<00:00, 33.48it/s]

epoch, Training loss:  3 -4739.919529557228



100%|██████████| 2297/2297 [01:08<00:00, 33.46it/s]

epoch, Training loss:  4 -4914.374559640884



100%|██████████| 2297/2297 [01:08<00:00, 33.39it/s]

epoch, Training loss:  5 -5037.101469635963



100%|██████████| 2297/2297 [01:09<00:00, 32.97it/s]

epoch, Training loss:  6 -5179.99608707428



100%|██████████| 2297/2297 [01:09<00:00, 32.87it/s]

epoch, Training loss:  7 -5329.654037833214



100%|██████████| 2297/2297 [01:10<00:00, 32.75it/s]

epoch, Training loss:  8 -5500.192034482956



100%|██████████| 2297/2297 [01:10<00:00, 32.68it/s]

epoch, Training loss:  9 -5667.287063121796



100%|██████████| 2297/2297 [01:10<00:00, 32.52it/s]

epoch, Training loss:  10 -5791.73764705658



100%|██████████| 2297/2297 [01:10<00:00, 32.36it/s]

epoch, Training loss:  11 -5895.33956861496



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  12 -5989.767723798752



100%|██████████| 2297/2297 [01:11<00:00, 32.21it/s]

epoch, Training loss:  13 -6060.681067466736



100%|██████████| 2297/2297 [01:10<00:00, 32.47it/s]

epoch, Training loss:  14 -6123.079252719879



100%|██████████| 2297/2297 [01:10<00:00, 32.41it/s]

epoch, Training loss:  15 -6172.392023801804



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  16 -6215.07114648819



100%|██████████| 2297/2297 [01:10<00:00, 32.44it/s]

epoch, Training loss:  17 -6249.6136820316315



100%|██████████| 2297/2297 [01:10<00:00, 32.50it/s]

epoch, Training loss:  18 -6281.443498849869



100%|██████████| 2297/2297 [01:10<00:00, 32.50it/s]

epoch, Training loss:  19 -6314.632014989853



100%|██████████| 2297/2297 [01:10<00:00, 32.48it/s]

epoch, Training loss:  20 -6327.095587730408



100%|██████████| 2297/2297 [01:10<00:00, 32.55it/s]

epoch, Training loss:  21 -6352.3524787425995



100%|██████████| 2297/2297 [01:10<00:00, 32.36it/s]

epoch, Training loss:  22 -6375.473245859146



100%|██████████| 2297/2297 [01:10<00:00, 32.43it/s]

epoch, Training loss:  23 -6383.894538164139



100%|██████████| 2297/2297 [01:10<00:00, 32.49it/s]

epoch, Training loss:  24 -6402.215129852295



100%|██████████| 2297/2297 [01:11<00:00, 32.29it/s]

epoch, Training loss:  25 -6417.891879796982



100%|██████████| 2297/2297 [01:10<00:00, 32.43it/s]

epoch, Training loss:  26 -6426.865793466568



100%|██████████| 2297/2297 [01:10<00:00, 32.50it/s]

epoch, Training loss:  27 -6434.863570213318



100%|██████████| 2297/2297 [01:10<00:00, 32.53it/s]

epoch, Training loss:  28 -6437.522647857666



100%|██████████| 2297/2297 [01:10<00:00, 32.52it/s]

epoch, Training loss:  29 -6455.5113253593445



100%|██████████| 2297/2297 [01:10<00:00, 32.57it/s]

epoch, Training loss:  30 -6465.932829856873



100%|██████████| 2297/2297 [01:10<00:00, 32.37it/s]

epoch, Training loss:  31 -6472.186025619507



100%|██████████| 2297/2297 [01:10<00:00, 32.44it/s]

epoch, Training loss:  32 -6482.573924541473



100%|██████████| 2297/2297 [01:11<00:00, 32.34it/s]

epoch, Training loss:  33 -6489.924809455872



100%|██████████| 2297/2297 [01:11<00:00, 32.28it/s]

epoch, Training loss:  34 -6488.994112491608



100%|██████████| 2297/2297 [01:11<00:00, 32.32it/s]

epoch, Training loss:  35 -6492.583439826965



100%|██████████| 2297/2297 [01:10<00:00, 32.45it/s]

epoch, Training loss:  36 -6504.955135822296



100%|██████████| 2297/2297 [01:11<00:00, 32.29it/s]

epoch, Training loss:  37 -6509.662816762924



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  38 -6521.078746080399



100%|██████████| 2297/2297 [01:11<00:00, 32.33it/s]

epoch, Training loss:  39 -6520.000031471252



100%|██████████| 2297/2297 [01:10<00:00, 32.44it/s]

epoch, Training loss:  40 -6516.431742668152



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  41 -6530.171114683151



100%|██████████| 2297/2297 [01:10<00:00, 32.45it/s]

epoch, Training loss:  42 -6537.167949438095



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  43 -6541.7511484622955



100%|██████████| 2297/2297 [01:11<00:00, 32.35it/s]

epoch, Training loss:  44 -6533.436820983887



100%|██████████| 2297/2297 [01:10<00:00, 32.36it/s]

epoch, Training loss:  45 -6545.088688135147



100%|██████████| 2297/2297 [01:11<00:00, 32.28it/s]

epoch, Training loss:  46 -6546.35786151886



100%|██████████| 2297/2297 [01:10<00:00, 32.57it/s]

epoch, Training loss:  47 -6562.988699197769



100%|██████████| 2297/2297 [01:10<00:00, 32.45it/s]

epoch, Training loss:  48 -6554.02494263649



100%|██████████| 2297/2297 [01:10<00:00, 32.48it/s]

epoch, Training loss:  49 -6558.010581493378



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  50 -6558.883900403976



100%|██████████| 2297/2297 [01:11<00:00, 32.29it/s]

epoch, Training loss:  51 -6567.223198413849



100%|██████████| 2297/2297 [01:10<00:00, 32.56it/s]

epoch, Training loss:  52 -6566.865798473358



100%|██████████| 2297/2297 [01:10<00:00, 32.41it/s]

epoch, Training loss:  53 -6564.303084611893



100%|██████████| 2297/2297 [01:10<00:00, 32.37it/s]

epoch, Training loss:  54 -6565.314704418182



100%|██████████| 2297/2297 [01:11<00:00, 32.33it/s]

epoch, Training loss:  55 -6573.282817363739



100%|██████████| 2297/2297 [01:10<00:00, 32.47it/s]

epoch, Training loss:  56 -6576.23489689827



100%|██████████| 2297/2297 [01:10<00:00, 32.43it/s]

epoch, Training loss:  57 -6573.928010702133



100%|██████████| 2297/2297 [01:10<00:00, 32.39it/s]

epoch, Training loss:  58 -6584.21563243866



100%|██████████| 2297/2297 [01:10<00:00, 32.35it/s]


epoch, Training loss:  59 -6575.954509019852


100%|██████████| 2297/2297 [01:11<00:00, 32.20it/s]

epoch, Training loss:  60 -6586.907736539841



100%|██████████| 2297/2297 [01:10<00:00, 32.45it/s]

epoch, Training loss:  61 -6579.751161813736



100%|██████████| 2297/2297 [01:10<00:00, 32.43it/s]

epoch, Training loss:  62 -6591.7231528759



100%|██████████| 2297/2297 [01:10<00:00, 32.57it/s]

epoch, Training loss:  63 -6584.721755027771



100%|██████████| 2297/2297 [01:10<00:00, 32.42it/s]

epoch, Training loss:  64 -6597.804543495178



100%|██████████| 2297/2297 [01:11<00:00, 32.32it/s]

epoch, Training loss:  65 -6594.175172328949



100%|██████████| 2297/2297 [01:10<00:00, 32.44it/s]

epoch, Training loss:  66 -6608.1840443611145



100%|██████████| 2297/2297 [01:10<00:00, 32.49it/s]

epoch, Training loss:  67 -6612.003643751144



100%|██████████| 2297/2297 [01:10<00:00, 32.50it/s]

epoch, Training loss:  68 -6601.7912774086



100%|██████████| 2297/2297 [01:10<00:00, 32.45it/s]

epoch, Training loss:  69 -6605.23509812355



100%|██████████| 2297/2297 [01:10<00:00, 32.43it/s]

epoch, Training loss:  70 -6596.798934459686



100%|██████████| 2297/2297 [01:11<00:00, 32.25it/s]

epoch, Training loss:  71 -6608.333466053009



100%|██████████| 2297/2297 [01:10<00:00, 32.44it/s]

epoch, Training loss:  72 -6602.05374956131



100%|██████████| 2297/2297 [01:10<00:00, 32.36it/s]

epoch, Training loss:  73 -6604.045810461044



100%|██████████| 2297/2297 [01:11<00:00, 32.32it/s]

epoch, Training loss:  74 -6600.523146152496



100%|██████████| 2297/2297 [01:11<00:00, 32.27it/s]

epoch, Training loss:  75 -6601.204106807709



100%|██████████| 2297/2297 [01:10<00:00, 32.37it/s]

epoch, Training loss:  76 -6613.416275262833



100%|██████████| 2297/2297 [01:10<00:00, 32.58it/s]

epoch, Training loss:  77 -6607.604194641113



100%|██████████| 2297/2297 [01:10<00:00, 32.69it/s]

epoch, Training loss:  78 -6613.619949579239



100%|██████████| 2297/2297 [01:10<00:00, 32.71it/s]

epoch, Training loss:  79 -6612.437024354935



100%|██████████| 2297/2297 [01:10<00:00, 32.64it/s]

epoch, Training loss:  80 -6616.663692474365



100%|██████████| 2297/2297 [01:10<00:00, 32.66it/s]

epoch, Training loss:  81 -6614.602504491806



100%|██████████| 2297/2297 [01:10<00:00, 32.72it/s]

epoch, Training loss:  82 -6614.7091472148895



100%|██████████| 2297/2297 [01:09<00:00, 32.83it/s]

epoch, Training loss:  83 -6625.794642925262



100%|██████████| 2297/2297 [01:10<00:00, 32.73it/s]

epoch, Training loss:  84 -6620.151348829269



100%|██████████| 2297/2297 [01:10<00:00, 32.76it/s]

epoch, Training loss:  85 -6624.310397148132



100%|██████████| 2297/2297 [01:10<00:00, 32.81it/s]

epoch, Training loss:  86 -6626.656254291534



100%|██████████| 2297/2297 [01:09<00:00, 32.92it/s]

epoch, Training loss:  87 -6623.281843185425



100%|██████████| 2297/2297 [01:09<00:00, 32.87it/s]

epoch, Training loss:  88 -6615.149544477463



100%|██████████| 2297/2297 [01:10<00:00, 32.81it/s]

epoch, Training loss:  89 -6628.613304138184



100%|██████████| 2297/2297 [01:09<00:00, 32.96it/s]

epoch, Training loss:  90 -6622.982315778732



100%|██████████| 2297/2297 [01:09<00:00, 33.14it/s]

epoch, Training loss:  91 -6627.8914659023285



100%|██████████| 2297/2297 [01:09<00:00, 33.23it/s]

epoch, Training loss:  92 -6630.112178325653



100%|██████████| 2297/2297 [01:09<00:00, 33.22it/s]

epoch, Training loss:  93 -6619.426297664642



100%|██████████| 2297/2297 [01:09<00:00, 33.27it/s]

epoch, Training loss:  94 -6622.13078045845



100%|██████████| 2297/2297 [01:08<00:00, 33.34it/s]

epoch, Training loss:  95 -6631.590719461441



100%|██████████| 2297/2297 [01:08<00:00, 33.37it/s]

epoch, Training loss:  96 -6623.379580974579



100%|██████████| 2297/2297 [01:09<00:00, 33.29it/s]

epoch, Training loss:  97 -6639.638449668884



100%|██████████| 2297/2297 [01:09<00:00, 33.14it/s]

epoch, Training loss:  98 -6625.514168024063



100%|██████████| 2297/2297 [01:09<00:00, 33.22it/s]

epoch, Training loss:  99 -6630.938138484955





In [9]:
class PointNetEncoder(nn.Module):
    def __init__(self, global_feat=True, channel=3):
        super(PointNetEncoder, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        B, D, N = x.size()
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        return x
        

class pointnet_cls(nn.Module):
    def __init__(self, k=2000, normal_channel=False):
        super(pointnet_cls, self).__init__()
        if normal_channel:
            channel = 6
        else:
            channel = 3
        self.feat = PointNetEncoder(global_feat=True, channel=channel)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.feat(x)
        fea = x
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        x = F.softmax(x, dim=1)
        return fea, x

In [11]:
train_data = ABC2Dataset_pc_test(args.csv_dir, args.train_dir, args.taskType)
train_loader = DataLoader(train_data, batch_size=args.batchSize, num_workers=args.workers, shuffle=False)

In [62]:
model = pointnet_cls(k=32)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.load_state_dict(torch.load('./model_iic_k32.pt'))
model.eval()

pointnet_cls(
  (feat): PointNetEncoder(
    (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
    (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=32, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
)

In [63]:
modelid_list = []
fea_list = []
label_list = []
with torch.no_grad():
    for i, data in enumerate(train_loader):
        pc, id_str = data
        fea, pred = model(pc.cuda())
        label = torch.argmax(pred, 1)
        print(i, fea.shape, pc.shape,pred.shape)
        fea_list.append(fea.detach().cpu().numpy())
        modelid_list.append(id_str)
        label_list.append(label.detach().cpu().numpy())


0 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
3 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
4 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
5 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
6 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
7 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
8 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
9 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
10 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
11 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
12 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
13 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Siz

147 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
148 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
149 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
150 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
151 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
152 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
153 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
154 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
155 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
156 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
157 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
158 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
159 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
160 torch.Size([10, 1024]) torch.Size(

264 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
265 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
266 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
267 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
268 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
269 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
270 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
271 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
272 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
273 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
274 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
275 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
276 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
277 torch.Size([10, 1024]) torch.Size(

380 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
381 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
382 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
383 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
384 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
385 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
386 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
387 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
388 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
389 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
390 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
391 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
392 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
393 torch.Size([10, 1024]) torch.Size(

499 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
500 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
501 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
502 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
503 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
504 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
505 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
506 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
507 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
508 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
509 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
510 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
511 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
512 torch.Size([10, 1024]) torch.Size(

616 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
617 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
618 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
619 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
620 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
621 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
622 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
623 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
624 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
625 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
626 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
627 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
628 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
629 torch.Size([10, 1024]) torch.Size(

733 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
734 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
735 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
736 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
737 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
738 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
739 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
740 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
741 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
742 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
743 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
744 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
745 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
746 torch.Size([10, 1024]) torch.Size(

851 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
852 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
853 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
854 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
855 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
856 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
857 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
858 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
859 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
860 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
861 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
862 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
863 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
864 torch.Size([10, 1024]) torch.Size(

968 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
969 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
970 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
971 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
972 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
973 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
974 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
975 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
976 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
977 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
978 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
979 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
980 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
981 torch.Size([10, 1024]) torch.Size(

1087 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1088 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1089 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1090 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1091 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1092 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1093 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1094 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1095 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1096 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1097 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1098 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1099 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1100 torch.Size([10, 1024

1205 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1206 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1207 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1208 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1209 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1210 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1211 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1212 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1213 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1214 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1215 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1216 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1217 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1218 torch.Size([10, 1024

1322 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1323 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1324 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1325 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1326 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1327 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1328 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1329 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1330 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1331 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1332 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1333 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1334 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1335 torch.Size([10, 1024

1439 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1440 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1441 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1442 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1443 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1444 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1445 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1446 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1447 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1448 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1449 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1450 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1451 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1452 torch.Size([10, 1024

1553 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1554 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1555 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1556 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1557 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1558 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1559 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1560 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1561 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1562 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1563 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1564 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1565 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1566 torch.Size([10, 1024

1669 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1670 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1671 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1672 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1673 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1674 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1675 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1676 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1677 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1678 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1679 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1680 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1681 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1682 torch.Size([10, 1024

1787 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1788 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1789 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1790 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1791 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1792 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1793 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1794 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1795 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1796 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1797 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1798 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1799 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1800 torch.Size([10, 1024

1906 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1907 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1908 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1909 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1910 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1911 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1912 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1913 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1914 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1915 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1916 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1917 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1918 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
1919 torch.Size([10, 1024

2024 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2025 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2026 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2027 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2028 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2029 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2030 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2031 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2032 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2033 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2034 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2035 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2036 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2037 torch.Size([10, 1024

2141 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2142 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2143 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2144 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2145 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2146 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2147 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2148 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2149 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2150 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2151 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2152 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2153 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2154 torch.Size([10, 1024

2256 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2257 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2258 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2259 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2260 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2261 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2262 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2263 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2264 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2265 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2266 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2267 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2268 torch.Size([10, 1024]) torch.Size([10, 3, 4096]) torch.Size([10, 32])
2269 torch.Size([10, 1024

In [64]:
import numpy as np

In [65]:
modelid_list2 = []
for i in range(len(modelid_list)):
    modelid_list2 += modelid_list[i]

In [66]:
print(len(fea_list))
fea_list2 = np.concatenate(fea_list,0)
print(fea_list2.shape)

2297
(22968, 1024)


In [67]:
np.save("latent_space_iic_K32", fea_list2)

In [68]:
label_list2 = np.concatenate((label_list),0)

In [69]:
np.unique(label_list2)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

In [70]:
cluster_id_class_file = open("cluster_id_class_iic_pc_baseline_K32.txt", 'w')

for i in range(len(modelid_list2)):
    cluster_id_class_file.write(modelid_list2[i]+", "+str(label_list2[i]))
    cluster_id_class_file.write('\n')
cluster_id_class_file.close()