# In [None]:
# STL10 variant for a rotation-prediction pretext task: torchvision's STL10
# dataset is subclassed and only __getitem__ is redefined.
class RotateSTL10(datasets.STL10):
    """torchvision ``STL10`` whose samples are randomly rotated images.

    Each access returns ``(image, angle_index)``, where ``angle_index`` is the
    position in ``all_perms`` of the angle the image was rotated by.
    """

    # Candidate rotation angles in degrees; the returned target is the index
    # of the chosen angle in this list.
    all_perms = [0, 45, 90, 135, 180, 225, 270]

    def __getitem__(self, index):
        """Return ``(rotated_image, angle_index)`` for sample *index*.

        The image is rotated with ``FT.rotate`` first and then passed through
        ``self.transform`` when one is set. Neither ``target_transform`` nor
        the dataset's own class labels are used.
        """
        # The stored array is channel-first; PIL wants channel-last, and a PIL
        # image keeps this consistent with other torchvision datasets.
        pil_img = Image.fromarray(
            np.transpose(self.data[index], (1, 2, 0))
        )

        # Pick one of the candidate angles uniformly at random.
        angle_idx = random.randrange(len(self.all_perms))
        rotated = FT.rotate(pil_img, angle=self.all_perms[angle_idx])

        if self.transform is None:
            return rotated, angle_idx
        return self.transform(rotated), angle_idx

# In [None]:
# Create a STL10 dataset by inheriting Pytorch's exisitng STL10 
# and re-defining the __getitem__ method
class ShuffleSTL10(datasets.STL10):
    
    # Define the hight and width of the "puzzle" grid !
    puzzle_size = 3
    # Set the maximum number of permutations
    max_perms = 100
    
    # Determine all possible permutations of the puzzle pieces
    iter_array = itertools.permutations(np.arange(puzzle_size**2))
    all_perms = []
    for arr in iter_array:
        all_perms.append(torch.tensor([arr]))
        
        if len(all_perms) == max_perms:
            break

    def __getitem__(self, index):
            # Select image using index
            img = self.data[index]
            
            # doing this so that it is consistent with all other datasets
            # to return a PIL Image
            img = Image.fromarray(np.transpose(img, (1, 2, 0)))
            if self.transform is not None:
                img = self.transform(img)
                
            # Determine number of pixels per puzzel piece
            img_size = img.shape[-1]
            puzzle_sections = self.puzzle_size**2
            
            # Use Pytorch Shuffle and UnShuffle to move pieces around
            unshuffle = nn.PixelUnshuffle(img_size//self.puzzle_size)
            shuffle = nn.PixelShuffle(img_size//self.puzzle_size)
            
            # Randomly select one permutation of the puzzle
            rand_int = random.randint(0, len(self.all_perms) - 1)
            perm = self.all_perms[rand_int]
            
            # Shuffle the puzzle pieces
            img_out = unshuffle(img.unsqueeze(0))
            img_out = img_out.reshape(1, img.shape[0], -1, puzzle_sections)
            img_out = shuffle(img_out[:, :, :, perm].reshape(1, -1, 
                                                                  self.puzzle_size, 
                                                                  self.puzzle_size))

            return img_out.squeeze(0), rand_int

# In [None]:
# Preprocessing pipeline: resize to `image_size`, convert to a tensor, then
# normalise each channel with the commonly used ImageNet mean / std values.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])