Testing functions and classes

In [0]:
# Maps each label name found in the per-song test labels file to the index of
# that label as seen by the network.
# NOTE(review): index 2 is not assigned to any name here -- presumably a network
# class that never occurs in the labels file; confirm.
LABELS_FROM_FILE = {
    'angry': 0,
    'calming': 1,
    'happy': 3,
    'normal': 4,
    'sad': 5,
}

In [0]:
class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose samples also carry the image file path.

    Each item is whatever ``torchvision.datasets.ImageFolder`` returns for the
    index, with the path of that image appended as the last tuple element.
    """

    def __getitem__(self, index):
        # The DataLoader fetches samples through this method.
        sample = super().__getitem__(index)
        # The first field of self.imgs[index] is the file path of the image.
        image_path = self.imgs[index][0]
        return sample + (image_path,)

def test_network_with_songs_data_return(net, test_dataset, batch_size, num_workers=4):
    """Run ``net`` over ``test_dataset`` and count the predicted classes per song.

    The song identifier is the part of each image file name before the first
    underscore; every image votes for the class with the highest network output.

    Args:
        net: Model to evaluate; it is put in evaluation mode and moved to
            ``DEVICE``.
        test_dataset: Dataset yielding ``(image, label, path)`` samples, e.g.
            an ``ImageFolderWithPaths``.
        batch_size: Number of images per forward pass.
        num_workers: Worker processes used by the ``DataLoader``. Defaults to 4,
            the value that was previously hard-coded.

    Returns:
        dict mapping each song identifier to ``{"preds": counts}``, where
        ``counts`` is an integer NumPy array of length ``NUM_CLASSES`` holding
        one vote per image of that song.
    """
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    net.train(False)
    net = net.to(DEVICE)

    test_songs_data = dict()

    # Inference only: without no_grad() every forward pass would record an
    # autograd graph and hold on to memory that is never used.
    with torch.no_grad():
        # The ground-truth labels coming from the dataset are not needed here.
        for images, _, paths in test_dataloader:
            images = images.to(DEVICE)

            # Forward pass
            outputs = net(images)

            # Index of the highest output for every image in the batch.
            _, preds = torch.max(outputs, 1)

            # tolist() yields plain Python ints, which index the NumPy vote
            # counters regardless of the device the tensor lives on.
            for path, pred in zip(paths, preds.tolist()):
                image_name = os.path.basename(path)
                song_idx = image_name.split("_")[0]

                if song_idx not in test_songs_data:
                    test_songs_data[song_idx] = {"preds": np.zeros(NUM_CLASSES, dtype=int)}

                test_songs_data[song_idx]["preds"][pred] += 1

    return test_songs_data

def read_songs_labels(path):
    """Read the per-song labels file and convert label names to indices.

    Every non-blank line is expected to look like ``<song>:<label>, <label>...``
    followed by two trailing characters that are discarded.

    Args:
        path: Path of the text file to read.

    Returns:
        dict mapping each song name (the text before the first ``:``) to the
        list of label indices looked up in ``LABELS_FROM_FILE``. Labels of a
        song that appears on several lines are accumulated in file order.

    Raises:
        KeyError: If a label name is not a key of ``LABELS_FROM_FILE``.
    """
    song_labels = dict()

    # The context manager closes the file even when parsing fails.
    with open(path, "r") as labels_file:
        for line in labels_file:
            # Blank lines (e.g. a trailing empty line) carry no song.
            if not line.strip():
                continue

            names = line.split(":")
            labels = names[1].replace(" ", "").split(",")
            # Drop the last two characters of the line.
            # NOTE(review): presumably a terminator character plus the newline;
            # confirm against the labels file format.
            labels[-1] = labels[-1][:-2]

            song_entry = song_labels.setdefault(names[0], [])
            for label in labels:
                song_entry.append(LABELS_FROM_FILE[label])

    return song_labels

def major_voting_analyze(songs_data, songs_labels):
    """Print the majority-vote prediction of every song and the test accuracy.

    For each song the predicted class is the index with the most votes in
    ``songs_data[song]["preds"]`` (the first such index on ties). A prediction
    counts as correct when it is one of the label indices of the song it is
    paired with.

    NOTE(review): songs are paired by position -- the i-th song of
    ``songs_data`` (in iteration order) is compared with the i-th key of
    ``songs_labels`` in sorted order. Confirm this alignment matches how the
    two inputs are produced.

    Args:
        songs_data: dict mapping a song identifier to ``{"preds": counts}``,
            where ``counts`` is a sequence of per-class vote counts.
        songs_labels: dict mapping a song name to the list of accepted label
            indices.

    Returns:
        None. The per-song predictions and the accuracy are printed.
    """
    ordered_keys = sorted(songs_labels.keys())
    print(ordered_keys)

    prediction = dict()
    for key in songs_data.keys():
        votes = songs_data[key]["preds"]
        # max() keeps the first index on ties; the default covers empty votes.
        prediction[key] = max(range(len(votes)), key=lambda i: votes[i], default=0)

    if len(prediction) != len(ordered_keys):
        # zip() below stops at the shorter side; make the mismatch visible.
        print("Warning: {} predicted songs but {} labelled songs; only the first {} pairs are compared".format(
            len(prediction), len(ordered_keys), min(len(prediction), len(ordered_keys))))

    corrects = 0
    compared = 0
    for song, label_key in zip(prediction, ordered_keys):
        accepted_labels = songs_labels[label_key]
        if prediction[song] in accepted_labels:
            corrects += 1
        print("Prediction for song {} - {}: {}; labels: {}".format(song, label_key, prediction[song], accepted_labels))
        compared += 1

    # Guard against an empty evaluation instead of dividing by zero.
    test_accuracy = corrects / compared if compared else 0.0
    print("Test accuracy: {}".format(test_accuracy))

def get_test_dataset(test_data_dir):
    eval_transform = transforms.Compose([
          transforms.Resize(224),
          transforms.CenterCrop(224),
          transforms.ToTensor()
          ])
    
    if not os.path.isdir('./AIML_project'):
        !git clone https://github.com/anphetamina/AIML_project.git
    
    test_dataset = ImageFolderWithPaths(test_data_dir, transform=eval_transform)

    return test_dataset

Code example

In [0]:
torch.cuda.empty_cache()

TEST_DATA_DIR = 'AIML_project/CAL500_test_sliced_spectrograms'
test_dataset = get_test_dataset(TEST_DATA_DIR)
# Report the size of the dataset that was just built (the previous name,
# test_dataset_copy, is not defined in this cell or the helpers above).
print('test set {}'.format(len(test_dataset)))

# `net` is the network obtained from training; images are fed one per batch.
songs_data = test_network_with_songs_data_return(net, test_dataset, 1)

print(songs_data)

songs_labels = read_songs_labels("AIML_project/songs_filtered_with_labels.txt")

print(songs_labels)

major_voting_analyze(songs_data, songs_labels)