In [0]:
# Download the Kaggle plant-seedlings archive and unpack the outer zip.
!wget http://bradheintz.com/kaggle/plant-seedlings-classification.zip
!unzip plant-seedlings-classification.zip
# Create the directory layout used by the loaders (data/train, data/test).
!mkdir data
!mkdir data/train
!mkdir data/test
# Move each inner archive into its directory, then extract it.
# NOTE(review): unzip is called without -d, so files are extracted relative to
# the current working directory, not into data/train or data/test. Confirm the
# resulting layout matches the paths the dataset classes read from.
!mv train.zip data/train
!unzip data/train/train.zip
!mv test.zip data/test
!unzip data/test/test.zip

In [None]:
from imageio import imread
import numpy as np
from PIL import Image
from io import BytesIO



class SeedlingTestDataset(torch.utils.data.Dataset):
    """Dataset over the unlabeled PNG images found directly in one directory.

    Items are ordered by sorted file name. Each item is an
    ``(image, filename)`` pair, so the "target" slot carries the file name
    rather than a class label.
    """

    def __init__(self, path_to_test_data='data/test', transform=None):
        """Index the PNG files in ``path_to_test_data``.

        Args:
            path_to_test_data: Directory scanned (non-recursively) for PNGs.
            transform: Optional callable applied to each PIL image before it
                is returned from ``__getitem__``.
        """
        self.transform = transform
        self.data, self.datasize = self.build_dataset_from_path(path_to_test_data)
        self.filenames = sorted(self.data)

    def build_dataset_from_path(self, test_data_path):
        """Map each PNG file name in ``test_data_path`` to its full path.

        Args:
            test_data_path: Directory to scan; subdirectories are ignored.

        Returns:
            A ``(mapping, count)`` tuple where ``mapping`` is
            ``{file name: joined path}`` and ``count`` is ``len(mapping)``.
        """
        data = {}
        for item in listdir(test_data_path):
            file_path = join(test_data_path, item)
            # Test the file name's extension rather than searching the whole
            # path for "png": a substring match would also accept files like
            # "png_notes.txt", or every file when the directory path itself
            # happens to contain "png".
            if item.lower().endswith('.png') and isfile(file_path):
                data[item] = file_path
        return data, len(data)

    def __len__(self):
        """Return the number of indexed images."""
        return self.datasize

    def __getitem__(self, index):
        """Load the image at ``index`` (in sorted file-name order).

        Returns:
            ``(img, key)`` where ``img`` is the PIL image, or the output of
            ``transform`` when one was given, and ``key`` is the file name.
        """
        key = self.filenames[index]
        full_path = self.data[key]

        # Read the whole file into memory so the handle is released even if
        # reading fails; PIL then decodes from the in-memory buffer.
        with open(full_path, 'rb') as f:
            img = Image.open(BytesIO(f.read()))
        if self.transform is not None:
            img = self.transform(img)

        return img, key

def get_test_loader(args, kwargs):
    """Build the DataLoader that yields test images one at a time.

    Args:
        args: Unused here; accepted so the call shape matches the other
            loader factory used alongside this one.
        kwargs: Mapping of extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader``.

    Returns:
        A DataLoader over ``SeedlingTestDataset('data/test')`` with
        ``batch_size=1`` and ``shuffle=False``; each batch is
        ``(image tensor, file names)``.
    """
    # Inference preprocessing must be deterministic so repeated runs give the
    # same predictions: take the centre crop and apply no random flip.
    incoming_transforms = transforms.Compose([
        transforms.Resize(100),
        transforms.CenterCrop(100),
        transforms.ToTensor(),
    ])
    dataset = SeedlingTestDataset(path_to_test_data='data/test', transform=incoming_transforms)
    return torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False, **kwargs)

def test_model(model, device, kwargs):
    """Predict a class for every test image and write ``submission.csv``.

    The output file starts with the header ``file,species`` followed by one
    ``<file name>,<class name>`` row per test image.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device each input batch is moved to before the forward pass.
        kwargs: Extra keyword arguments forwarded to the loader factories.
    """
    model.eval()
    # NOTE: `args` is not a parameter of this function; it is resolved from
    # module scope at call time.
    # TODO this is horrible: the training loader is built only to read the
    # class-name list from its dataset.
    classes = get_training_loader(args, kwargs).dataset.classes
    loader = get_test_loader(args, kwargs)

    # The context manager guarantees the CSV is flushed and closed even if
    # inference raises part-way through.
    with open('submission.csv', 'w') as outfile, torch.no_grad():
        outfile.write('file,species\n')
        for data, names in loader:
            # Inputs must be on the requested device for the forward pass.
            output = model(data.to(device))
            _, pred = torch.max(output, 1)
            outfile.write('{},{}\n'.format(names[0], classes[pred.item()]))  # batches of 1