In [0]:
import os
from google.colab import drive
drive.mount('/content/drive')
# 加载 google 云盘

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
!cat /usr/local/cuda/version.txt
from tensorflow.python.client import device_lib
device_lib.list_local_devices()
# CUDA Version 10.1.243

CUDA Version 10.1.243


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm   # 进度条
import warnings

# path = '/content/drive/My Drive/modelnet40_normal_resampled/modelnet40_shape_names.txt'
# open(path)
print(torch.cuda.is_available())
torch.backends.cudnn.enabled = False  
warnings.filterwarnings('ignore')

True


In [0]:
# ModelNet40 数据集加载
def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc

class ModelNetDataLoader(Dataset):
  def __init__(self, root,  npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
    self.root = root
    self.npoints = npoint
    self.uniform = uniform
    self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

    self.cat = [line.rstrip() for line in open(self.catfile)]
    self.classes = dict(zip(self.cat, range(len(self.cat))))
    self.normal_channel = normal_channel

    shape_ids = {}
    shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
    shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

    assert (split == 'train' or split == 'test')
    shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
    # list of (shape_name, shape_txt_file_path) tuple
    self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                      in range(len(shape_ids[split]))]
    print('The size of %s data is %d'%(split,len(self.datapath)))

    self.cache_size = cache_size  # how many data points to cache in memory
    self.cache = {}  # from index to (point_set, cls) tuple

  def __len__(self):
    return len(self.datapath)

  def _get_item(self, index):
    if index in self.cache:
      point_set, cls = self.cache[index]
    else:
      fn = self.datapath[index]
      cls = self.classes[self.datapath[index][0]]
      cls = np.array([cls]).astype(np.int32)
      point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
      if self.uniform:
        point_set = self.farthest_point_sample(point_set, self.npoints)
      else:
        point_set = point_set[0:self.npoints,:]

      point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

      if not self.normal_channel:
        point_set = point_set[:, 0:3]

      if len(self.cache) < self.cache_size:
        self.cache[index] = (point_set, cls)

    return point_set, cls

  def __getitem__(self, index):
    return self._get_item(index)

  def farthest_point_sample(point, npoint):
    """
    Input:
        xyz: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:,:3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point

# test
# data = ModelNetDataLoader('/content/drive/My Drive/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,)
# DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
# for point,label in DataLoader:
#   print(point.shape)
#   print(label.shape)

In [0]:
def index_points(points, idx):
  """

  Input:
      points: input points data, [B, N, C]
      idx: sample index data, [B, S]
  Return:
      new_points:, indexed points data, [B, S, C]
  """
  device = points.device
  B = points.shape[0]
  view_shape = list(idx.shape)
  view_shape[1:] = [1] * (len(view_shape) - 1)
  repeat_shape = list(idx.shape)
  repeat_shape[0] = 1
  batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
  new_points = points[batch_indices, idx, :]
  return new_points

def square_distance(src, dst):
  """
  Calculate Euclid distance between each two points.

  src^T * dst = xn * xm + yn * ym + zn * zm；
  sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
  sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
  dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
        = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

  Input:
      src: source points, [B, N, C]
      dst: target points, [B, M, C]
  Output:
      dist: per-point square distance, [B, N, M]
  """
  B, N, _ = src.shape
  _, M, _ = dst.shape
  dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
  dist += torch.sum(src ** 2, -1).view(B, N, 1)
  dist += torch.sum(dst ** 2, -1).view(B, 1, M)
  return dist


def query_ball_point(radius, nsample, xyz, new_xyz):
  """
  Input:
      radius: local region radius
      nsample: max sample number in local region
      xyz: all points, [B, N, 3]
      new_xyz: query points, [B, S, 3]
  Return:
      group_idx: grouped points index, [B, S, nsample]
  """
  device = xyz.device
  B, N, C = xyz.shape
  _, S, _ = new_xyz.shape
  group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
  sqrdists = square_distance(new_xyz, xyz)
  group_idx[sqrdists > radius ** 2] = N
  group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
  group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
  mask = group_idx == N
  group_idx[mask] = group_first[mask]
  return group_idx


def farthest_point_sample(xyz, npoint):
  """
  Input:
      xyz: pointcloud data, [B, N, 3]
      npoint: number of samples
  Return:
      centroids: sampled pointcloud index, [B, npoint]
  """
  device = xyz.device
  B, N, C = xyz.shape
  centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
  distance = torch.ones(B, N).to(device) * 1e10
  farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
  batch_indices = torch.arange(B, dtype=torch.long).to(device)
  # print(B,N,C)
  for i in range(npoint):
    centroids[:, i] = farthest
    centroid_temp = xyz[batch_indices, farthest, :].view(B, 1, 3)
    dist = torch.sum((xyz - centroid_temp) ** 2, -1)
    mask = dist < distance
    distance[mask] = dist[mask]
    farthest = torch.max(distance, -1)[1]
  return centroids

def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
  """
  Input:
      npoint:
      radius:
      nsample:
      xyz: input points position data, [B, N, 3]
      points: input points data, [B, N, D]
  Return:
      new_xyz: sampled points position data, [B, npoint, nsample, 3]
      new_points: sampled points data, [B, npoint, nsample, 3+D]
  """
  B, N, C = xyz.shape
  S = npoint
  fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
  torch.cuda.empty_cache()
  new_xyz = index_points(xyz, fps_idx)
  torch.cuda.empty_cache()
  idx = query_ball_point(radius, nsample, xyz, new_xyz)
  torch.cuda.empty_cache()
  grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
  torch.cuda.empty_cache()
  grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
  torch.cuda.empty_cache()

  if points is not None:
      grouped_points = index_points(points, idx)
      new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
  else:
      new_points = grouped_xyz_norm
  if returnfps:
      return new_xyz, new_points, grouped_xyz, fps_idx
  else:
      return new_xyz, new_points


def sample_and_group_all(xyz, points):
  """
  Input:
      xyz: input points position data, [B, N, 3]
      points: input points data, [B, N, D]
  Return:
      new_xyz: sampled points position data, [B, 1, 3]
      new_points: sampled points data, [B, 1, N, 3+D]
  """
  device = xyz.device
  B, N, C = xyz.shape
  new_xyz = torch.zeros(B, 1, C).to(device)
  grouped_xyz = xyz.view(B, 1, N, C)
  if points is not None:
      new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
  else:
      new_points = grouped_xyz
  return new_xyz, new_points

# SA层
class PointNetSetAbstraction(nn.Module):
  def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
    super(PointNetSetAbstraction, self).__init__()
    self.npoint = npoint
    self.radius = radius
    self.nsample = nsample
    self.mlp_convs = nn.ModuleList()
    self.mlp_bns = nn.ModuleList()
    last_channel = in_channel
    for out_channel in mlp:
        self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
        self.mlp_bns.append(nn.BatchNorm2d(out_channel))
        last_channel = out_channel
    self.group_all = group_all

  def forward(self, xyz, points):
    """
    Input:
        xyz: input points position data, [B, C, N]
        points: input points data, [B, D, N]
    Return:
        new_xyz: sampled points position data, [B, C, S]
        new_points_concat: sample points feature data, [B, D', S]
    """
    xyz = xyz.permute(0, 2, 1)
    if points is not None:
        points = points.permute(0, 2, 1)

    if self.group_all:
        new_xyz, new_points = sample_and_group_all(xyz, points)
    else:
        new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
    # new_xyz: sampled points position data, [B, npoint, C]
    # new_points: sampled points data, [B, npoint, nsample, C+D]
    new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
    for i, conv in enumerate(self.mlp_convs):
        bn = self.mlp_bns[i]
        new_points =  F.relu(bn(conv(new_points)))

    new_points = torch.max(new_points, 2)[0]
    new_xyz = new_xyz.permute(0, 2, 1)
    return new_xyz, new_points

# PointNet++
class Model(nn.Module):
  def __init__(self,num_class,normal_channel=True):
    super(Model, self).__init__()
    in_channel = 6 if normal_channel else 3
    self.normal_channel = normal_channel
    self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
    self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
    self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
    self.fc1 = nn.Linear(1024, 512)
    self.bn1 = nn.BatchNorm1d(512)
    self.drop1 = nn.Dropout(0.4)
    self.fc2 = nn.Linear(512, 256)
    self.bn2 = nn.BatchNorm1d(256)
    self.drop2 = nn.Dropout(0.4)
    self.fc3 = nn.Linear(256, num_class)

  def forward(self, xyz):
    B, _, _ = xyz.shape
    if self.normal_channel:
      norm = xyz[:, 3:, :]
      xyz = xyz[:, :3, :]
    else:
      norm = None
    l1_xyz, l1_points = self.sa1(xyz, norm)
    l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
    l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
    x = l3_points.view(B, 1024)
    x = self.drop1(F.relu(self.bn1(self.fc1(x))))
    x = self.drop2(F.relu(self.bn2(self.fc2(x))))
    x = self.fc3(x)
    x = F.log_softmax(x, -1)
    return x, l3_points

class Loss(nn.Module):
  def __init__(self):
    super(Loss, self).__init__()

  def forward(self, pred, target, trans_feat):
    return F.nll_loss(pred, target)
    

In [0]:
def shift_point_cloud(batch_data, shift_range=0.1):
  """ Randomly shift point cloud. Shift is per point cloud.
      Input:
        BxNx3 array, original batch of point clouds
      Return:
        BxNx3 array, shifted batch of point clouds
  """
  B, N, C = batch_data.shape
  shifts = np.random.uniform(-shift_range, shift_range, (B,3))
  for batch_index in range(B):
    batch_data[batch_index,:,:] += shifts[batch_index,:]
  return batch_data


def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
  """ Randomly scale the point cloud. Scale is per point cloud.
      Input:
          BxNx3 array, original batch of point clouds
      Return:
          BxNx3 array, scaled batch of point clouds
  """
  B, N, C = batch_data.shape
  scales = np.random.uniform(scale_low, scale_high, B)
  for batch_index in range(B):
    batch_data[batch_index,:,:] *= scales[batch_index]
  return batch_data

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
  ''' batch_pc: BxNx3 '''
  for b in range(batch_pc.shape[0]):
    dropout_ratio =  np.random.random()*max_dropout_ratio # 0~0.875
    drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
    if len(drop_idx)>0:
        batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
  return batch_pc


# test function
def test(model, loader, vote_num = 0, num_class = 40):
  mean_correct = []
  class_acc = np.zeros((num_class,3))
  for j, data in tqdm(enumerate(loader), total=len(loader)):
    points, target = data
    target = target[:, 0]
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    classifier = model.eval()
    if vote_num > 0:
      vote_pool = torch.zeros(target.size()[0],num_class).cuda()
      # vote_pool = torch.zeros(target.size()[0],num_class)
      for _ in range(vote_num):
        pred, _ = classifier(points)
        vote_pool += pred
      pred = vote_pool/vote_num
    else:
      pred, _ = classifier(points)
    pred_choice = pred.data.max(1)[1]
    for cat in np.unique(target.cpu()):
        classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum()
        class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0])
        class_acc[cat,1]+=1
    correct = pred_choice.eq(target.long().data).cpu().sum()
    mean_correct.append(correct.item()/float(points.size()[0]))
  class_acc[:,2] =  class_acc[:,0]/ class_acc[:,1]
  class_acc = np.mean(class_acc[:,2])
  instance_acc = np.mean(mean_correct)
  return instance_acc, class_acc

In [0]:
# train
'''parameters'''
batch_size = 24
train_epoch = 200
learning_rate = 0.001
decay_rate = 1e-4
num_point = 1024
num_class = 40
data_path = '/content/drive/My Drive/modelnet40_normal_resampled/'
model_save_path = '/content/drive/My Drive/best_model.pth'
gpu = '0'

os.environ["CUDA_VISIBLE_DEVICES"] = gpu

'''data loading'''
print('load dataset....')
train_dataset = ModelNetDataLoader(data_path,num_point,'train')
train_data_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=4,shuffle=True)
test_dataset = ModelNetDataLoader(data_path,num_point,'test')
test_data_loader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,num_workers=4,shuffle=False)
print('data loaded\n')

'''model loading'''
print('model loading')
classifier = Model(num_class).cuda()
criterion = Loss().cuda()

try:
  checkpoint = torch.load(model_save_path)
  start_epoch = checkpoint['epoch']
  classifier.load_state_dict(checkpoint['model_state_dict'])
  print('Use pretrain model, current epoch: %d' % (start_epoch))
except:
  print('No existing model, starting training from scratch...')
  start_epoch = 0

optimizer = torch.optim.Adam(classifier.parameters(),lr=learning_rate,betas=(0.9, 0.999),eps=1e-08,weight_decay=decay_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

global_epoch = 0
best_instance_acc = 0.0
best_class_acc = 0.0
mean_correct = []
print('model loaded\n')

load dataset....
The size of train data is 9843
The size of test data is 2468
data loaded

model loading
No existing model, starting training from scratch...
model loaded



In [0]:
'''training'''
print('start training')
for epoch in range(start_epoch, train_epoch):
  scheduler.step()
  for batch_id, data in tqdm(enumerate(train_data_loader, 0), total=len(train_data_loader), smoothing=0.9):
    points, target = data
    points = points.data.numpy()
    points = random_point_dropout(points)
    points[:,:, 0:3] = random_scale_point_cloud(points[:,:, 0:3])
    points[:,:, 0:3] = shift_point_cloud(points[:,:, 0:3])
    points = torch.Tensor(points)
    target = target[:, 0]

    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    optimizer.zero_grad()    

    classifier = classifier.train()
    pred, trans_feat = classifier(points)
    loss = criterion(pred, target.long(), trans_feat)
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.long().data).cpu().sum()
    mean_correct.append(correct.item() / float(points.size()[0]))
    loss.backward()
    optimizer.step()
  
  train_instance_acc = np.mean(mean_correct)
  print('Train %d Accuracy: %f' % (epoch + 1,train_instance_acc))

  with torch.no_grad():
    global_epoch += 1
    instance_acc, class_acc = test(classifier.eval(), test_data_loader)
    if instance_acc >= best_instance_acc:
      best_instance_acc = instance_acc   
      best_epoch = epoch + 1

    if class_acc >= best_class_acc:
      best_class_acc = class_acc
    if instance_acc >= best_instance_acc:
      state = {
        'epoch': best_epoch,
        'instance_acc': instance_acc,
        'class_acc': class_acc,
        'model_state_dict': classifier.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
        }
      torch.save(state, model_save_path)
    print('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc))

print('training fin!\n')

load dataset....
The size of train data is 9843
The size of test data is 2468
data loaded

model loading
Use pretrain model, current epoch: 51
model loaded

start training


100%|██████████| 411/411 [36:21<00:00,  5.31s/it]

Train 52 Accuracy: 0.905008



100%|██████████| 103/103 [08:53<00:00,  5.18s/it]


Best Instance Accuracy: 0.906311, Class Accuracy: 0.870559


100%|██████████| 411/411 [24:58<00:00,  3.65s/it]

Train 53 Accuracy: 0.906275



100%|██████████| 103/103 [08:35<00:00,  5.01s/it]


Best Instance Accuracy: 0.911893, Class Accuracy: 0.879669


100%|██████████| 411/411 [25:04<00:00,  3.66s/it]

Train 54 Accuracy: 0.908455



100%|██████████| 103/103 [08:18<00:00,  4.84s/it]

Best Instance Accuracy: 0.911893, Class Accuracy: 0.879669



100%|██████████| 411/411 [25:22<00:00,  3.70s/it]

Train 55 Accuracy: 0.907492



100%|██████████| 103/103 [07:58<00:00,  4.65s/it]

Best Instance Accuracy: 0.911893, Class Accuracy: 0.879669



100%|██████████| 411/411 [23:48<00:00,  3.47s/it]

Train 56 Accuracy: 0.906995



100%|██████████| 103/103 [07:40<00:00,  4.47s/it]


Best Instance Accuracy: 0.923058, Class Accuracy: 0.883634


100%|██████████| 411/411 [23:44<00:00,  3.47s/it]

Train 57 Accuracy: 0.907509



100%|██████████| 103/103 [07:59<00:00,  4.65s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.884882



100%|██████████| 411/411 [23:11<00:00,  3.39s/it]

Train 58 Accuracy: 0.907760



100%|██████████| 103/103 [07:44<00:00,  4.51s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.884882



100%|██████████| 411/411 [23:59<00:00,  3.50s/it]

Train 59 Accuracy: 0.908772



100%|██████████| 103/103 [08:24<00:00,  4.90s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [23:41<00:00,  3.46s/it]

Train 60 Accuracy: 0.908737



100%|██████████| 103/103 [07:30<00:00,  4.38s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [22:52<00:00,  3.34s/it]

Train 61 Accuracy: 0.907634



100%|██████████| 103/103 [07:23<00:00,  4.31s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [22:23<00:00,  3.27s/it]

Train 62 Accuracy: 0.907238



100%|██████████| 103/103 [07:34<00:00,  4.41s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [22:32<00:00,  3.29s/it]

Train 63 Accuracy: 0.907813



100%|██████████| 103/103 [07:34<00:00,  4.41s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [23:02<00:00,  3.36s/it]

Train 64 Accuracy: 0.908619



100%|██████████| 103/103 [07:26<00:00,  4.33s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [23:39<00:00,  3.45s/it]

Train 65 Accuracy: 0.909672



100%|██████████| 103/103 [07:23<00:00,  4.31s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [22:45<00:00,  3.32s/it]

Train 66 Accuracy: 0.910591



100%|██████████| 103/103 [08:02<00:00,  4.69s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [22:59<00:00,  3.36s/it]

Train 67 Accuracy: 0.910945



100%|██████████| 103/103 [07:17<00:00,  4.25s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [23:24<00:00,  3.42s/it]

Train 68 Accuracy: 0.911353



100%|██████████| 103/103 [07:25<00:00,  4.33s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [22:24<00:00,  3.27s/it]

Train 69 Accuracy: 0.911840



100%|██████████| 103/103 [07:19<00:00,  4.27s/it]

Best Instance Accuracy: 0.923058, Class Accuracy: 0.889218



100%|██████████| 411/411 [23:04<00:00,  3.37s/it]

Train 70 Accuracy: 0.912345



 23%|██▎       | 24/103 [01:59<04:50,  3.67s/it]

In [0]:
'''test'''
vote_num = 3
with torch.no_grad():
  instance_acc, class_acc = test(classifier.eval(), test_data_loader, vote_num = vote_num)
  print('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))

  0%|          | 0/103 [00:00<?, ?it/s]

KeyboardInterrupt: ignored