In [6]:
import torch as tc
import matplotlib.pyplot as plt
import numpy as np
import random

In [7]:
def generate_circle(y_size, x_size, y_origin, x_origin, radius):
    x_grid, y_grid = np.meshgrid(np.arange(x_size), np.arange(y_size))
    return (x_grid-x_origin)**2 + (y_grid-y_origin)**2 < radius **2

class CircleGenerator(tc.utils.data.Dataset):
    def __init__(self, y_size, x_size, min_radius, max_radius, num_circles):
        self.y_size, self.x_size = y_size, x_size
        self.min_radius, self.max_radius = min_radius, max_radius
        self.num_circles = num_circles
    
    def __len__(self):
        return self.num_circles
    
    def __getitem__(self, ids):
        x_origin = random.randrange(0, self.x_size)
        y_origin = random.randrange(0, self.y_size)
        radius = random.randrange(self.min_radius, self.max_radius)
        
        return generate_circle(self.y_size, self.x_size, y_origin, x_origin, radius)

In [10]:
dataset = CircleGenerator(256, 256, 10, 80, 16)
for i in range(len(dataset)):
    circle = dataset[i]
    # plt.imshow(circle, cmap='gray')
    # plt.show()

In [14]:
dataloader = tc.utils.data.DataLoader(dataset, batch_size=2, num_workers=0)

for batch in dataloader:
    print(f"Batch size: {batch.size}")
    print(f"Batch type: {type(batch)}")

Batch size: <built-in method size of Tensor object at 0x0000017A47FBBFB0>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A47FBB650>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A47FBB470>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A47FBB530>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A47FBB410>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A47FBB3B0>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A4B115FD0>
Batch type: <class 'torch.Tensor'>
Batch size: <built-in method size of Tensor object at 0x0000017A47FBB4D0>
Batch type: <class 'torch.Tensor'>
