In [None]:
# @title Data retrieval
def download_image(fname, url, expected_md5, timeout=60):
    """
    Downloads an image file from the given URL and saves it locally.

    Nothing is fetched when ``fname`` already exists. Otherwise the file is
    written to disk only if the HTTP request succeeds and the MD5 digest of
    the downloaded bytes matches ``expected_md5``; in every failure case a
    message is printed and no file is written.

    Inputs:
    - fname (str): The local filename/path to save the downloaded image.
    - url (str): The URL from which to download the image.
    - expected_md5 (str): The expected MD5 checksum to verify the integrity of the downloaded data.
    - timeout (float | None): Seconds to wait for the server before giving up.
      Defaults to 60; pass None to wait indefinitely.
    """
    if os.path.isfile(fname):
        # Already downloaded earlier; skip the network round-trip.
        return

    try:
        # Without a timeout, requests may block forever on an unresponsive server.
        r = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Base class of ConnectionError, Timeout and the other request failures.
        print("!!! Failed to download data !!!")
        return

    if r.status_code != requests.codes.ok:
        # Server answered, but not with a successful status.
        print("!!! Failed to download data !!!")
    elif hashlib.md5(r.content).hexdigest() != expected_md5:
        # Integrity check failed; do not keep a corrupted file on disk.
        print("!!! Data download appears corrupted !!!")
    else:
        with open(fname, "wb") as fid:
            fid.write(r.content)

# Files to fetch. Position i of each list describes one file: the local file
# name, the OSF URL it is downloaded from, and the MD5 digest that the
# downloaded bytes must match.
fnames = [
    "img_1235.jpg",
    "image_augmentation.png",
]
urls = [
    "https://osf.io/kv5bx/download",
    "https://osf.io/fqwsr/download",
]
expected_md5s = [
    "920ae567f707bfee0be29dc854f804ed",
    "f4f1ebee1470a7e2d7662eec1d193ba2",
]

# Fetch every file; download_image skips any file that already exists locally.
for fname, url, expected_md5 in zip(fnames, urls, expected_md5s):
    download_image(fname, url, expected_md5)

In [None]:
# @title Parameters 
# @markdown

class Parameters:
    """
    Fixed hyperparameters of the model, grouped by the component that uses
    them. Every value is stored as a plain attribute when the object is built.
    """

    def __init__(self):
        # Library to use
        self.libname = 'library'
        # Fill in each parameter group (order kept fixed).
        for configure in (self.set_rendering_params,
                          self.set_spline_params,
                          self.set_image_model_params,
                          self.set_mcmc_params,
                          self.set_search_params):
            configure()

    def set_rendering_params(self):
        """
        Parameters for painting strokes and post-processing the ink image
        """
        self.imsize = torch.Size([105, 105])  # image size in pixels

        # -- ink deposition --
        self.ink_pp = 2.            # amount of ink per trajectory point
        self.ink_max_dist = 2.      # point spacing at which a point receives full ink

        # -- broadening --
        self.ink_ncon = 2           # number of broadening convolutions
        self.ink_a = 0.5            # broadening parameter 1
        self.ink_b = 6.             # broadening parameter 2
        self.broaden_mode = 'Lake'  # broadening version (must be either "Lake" or "Hinton")

        # -- blurring --
        self.fsize = 11             # convolution filter size for blurring

    def set_spline_params(self):
        """
        Parameters for turning a spline into a trajectory of points
        """
        self.spline_max_neval = 200  # maximum number of evaluations
        self.spline_min_neval = 10   # minimum number of evaluations
        self.spline_grain = 1.5      # one trajectory point per this many pixels of distance

    def set_image_model_params(self):
        """
        Lower/upper bounds on the image model's noise parameters
        """
        self.max_blur_sigma = torch.tensor(16, dtype=torch.float)   # upper bound on blur sigma
        self.min_blur_sigma = torch.tensor(0.5, dtype=torch.float)  # lower bound on blur sigma
        self.max_epsilon = torch.tensor(0.5, dtype=torch.float)     # upper bound on pixel epsilon
        self.min_epsilon = torch.tensor(1e-4, dtype=torch.float)    # lower bound on pixel epsilon

    def set_mcmc_params(self):
        """
        MCMC chain lengths and proposal scales
        """
        # -- chain lengths --
        self.mcmc_nsamp_type_chain = 200  # samples taken in the MCMC chain (for classification)
        self.mcmc_nsamp_type_store = 10   # samples stored from that chain (for classification)
        self.mcmc_nsamp_token_chain = 25  # for completion (the last sample of this chain is used)

        # -- proposal scales --
        self.mcmc_prop_gpos_sd = 1         # global position move
        self.mcmc_prop_shape_sd = 1.5      # shape move
        self.mcmc_prop_scale_sd = 0.0235   # scale move
        self.mcmc_prop_relmid_sd = 0.2168  # attach relation move
        self.mcmc_prop_relpos_mlty = 2     # multiplier on the standard position-noise sd when proposing new positions from the prior

    def set_search_params(self):
        """
        Parameters of the search algorithm (part of inference)
        """
        self.K = 5                          # number of particles used by the search algorithm
        self.max_affine_scale_change = 2    # scale changes must be less than a factor of 2
        self.max_affine_shift_change = 50   # shift changes must be less than this

class Painter(nn.Module):
    """
    Rasterizes a single stroke (a sequence of points in motor space) into an
    ink image of size ``PM.imsize``, distributing each point's ink over its
    four neighbouring pixels with bilinear weights.
    """

    def __init__(self, PM=None):
        super().__init__()
        PM = Parameters() if PM is None else PM
        self.ink_pp = PM.ink_pp
        self.ink_max_dist = PM.ink_max_dist
        n_rows, n_cols = PM.imsize[0], PM.imsize[1]
        # Flattened pixel index for every (row, col) position of the canvas.
        self.register_buffer('index_mat', torch.arange(n_rows * n_cols).view(PM.imsize))
        # Sign pattern applied after swapping motor coordinates into image order.
        self.register_buffer('space_flip', torch.tensor([-1., 1.]))
        self.imsize = PM.imsize

    @property
    def device(self):
        """Device holding this module's buffers."""
        return self.index_mat.device

    @property
    def is_cuda(self):
        """True when this module's buffers live on a CUDA device."""
        return self.index_mat.is_cuda

    def space_motor_to_img(self, stk):
        """Map motor-space points (x, y) to image-space points (-y, x)."""
        swapped = stk.flip(-1)
        return swapped * self.space_flip

    def check_bounds(self, myt):
        """Boolean mask of points whose floor/ceil pixel lies outside the canvas."""
        rows, cols = myt[:, 0], myt[:, 1]
        row_bad = (rows.floor() < 0) | (rows.ceil() >= self.imsize[0])
        col_bad = (cols.floor() < 0) | (cols.ceil() >= self.imsize[1])
        return row_bad | col_bad

    def seqadd(self, D, lind_x, lind_y, inkval):
        """Return a copy of ``D`` with ``inkval`` accumulated at pixels (lind_x, lind_y)."""
        flat_idx = self.index_mat[lind_x.long(), lind_y.long()]
        return D.view(-1).scatter_add(0, flat_idx, inkval).view(self.imsize)

    def add_stroke(self, pimg, stk):
        """
        Paint one stroke onto an ink image.

        The stroke is mapped from motor space to image space and points whose
        neighbouring pixels would lie off the canvas are discarded. Each
        remaining point deposits ink proportional to its distance from the
        previous point (capped at ``ink_max_dist``); a lone point deposits
        ``ink_pp``. The stroke's total ink is raised to at least ``ink_pp``,
        then every point's ink is split bilinearly over its four neighbours.

        Inputs:
            pimg: ink image tensor of shape ``imsize`` to paint onto.
            stk: stroke in motor space, one (x, y) point per row.

        Outputs:
            pimg: the painted ink image (returned unchanged if every point is
                off the canvas).
            ink_off_page: 0-d boolean tensor, True if any point was discarded.
        """
        # Motor space -> image space
        pts = self.space_motor_to_img(stk)

        # Drop points that fall off the canvas, remembering whether any did
        off_mask = self.check_bounds(pts)
        ink_off_page = off_mask.any()
        if off_mask.all():
            return pimg, ink_off_page
        pts = pts[~off_mask]

        # Ink per point
        if pts.shape[0] > 1:
            # Gap to the previous point, capped; the first point reuses the first gap
            gaps = torch.norm(pts[1:] - pts[:-1], dim=-1)
            gaps = gaps.clamp(None, self.ink_max_dist)
            gaps = torch.cat([gaps[:1], gaps])
            ink_amount = (self.ink_pp / self.ink_max_dist) * gaps
        else:
            ink_amount = pts.new_tensor(self.ink_pp)

        # Guarantee the whole stroke carries at least ink_pp of ink
        total_ink = torch.sum(ink_amount)
        if total_ink < 2.22e-6:
            ink_amount = (self.ink_pp / ink_amount.shape[0]) * torch.ones_like(ink_amount)
        elif total_ink < self.ink_pp:
            ink_amount = ink_amount * (self.ink_pp / total_ink)
        assert torch.sum(ink_amount) > (self.ink_pp - 1e-4)

        # Bilinear split over the four neighbouring pixels
        rows, cols = pts[:, 0], pts[:, 1]
        r_lo, c_lo = rows.floor().detach(), cols.floor().detach()
        r_hi, c_hi = rows.ceil().detach(), cols.ceil().detach()
        w_r_hi, w_c_hi = rows - r_lo, cols - c_lo    # weight toward the ceil neighbour
        w_r_lo, w_c_lo = 1 - w_r_hi, 1 - w_c_hi      # weight toward the floor neighbour

        corners = (
            (r_lo, c_lo, w_r_lo, w_c_lo),
            (r_hi, c_lo, w_r_hi, w_c_lo),
            (r_lo, c_hi, w_r_lo, w_c_hi),
            (r_hi, c_hi, w_r_hi, w_c_hi),
        )
        lind_x = torch.cat([r for r, _, _, _ in corners])
        lind_y = torch.cat([c for _, c, _, _ in corners])
        ink_values = torch.cat([ink_amount * w_r * w_c for _, _, w_r, w_c in corners])

        return self.seqadd(pimg, lind_x, lind_y, ink_values), ink_off_page

    def draw(self, pimg, stroke):
        """Paint a single stroke onto ``pimg``; the off-page flag is discarded."""
        painted, _off_page = self.add_stroke(pimg, stroke)
        return painted

    def forward(self, stroke):
        """Render ``stroke`` onto a fresh all-zero canvas (CPU-only module)."""
        assert not self.is_cuda
        stroke = drawings_to_cpu(stroke)
        canvas = torch.zeros(*self.imsize)
        return self.draw(canvas, stroke)

class BroadenAndBlur(nn.Module):
    """
    Post-processes a painted ink image: repeated broadening convolutions,
    clipping to [0, 1], optional blurring, and optional pixel noise.
    """

    def __init__(self, blur_sigma=0.5, epsilon=0., blur_fsize=None, PM=None):
        super().__init__()
        if PM is None:
            PM = Parameters()
        if blur_fsize is None:
            blur_fsize = PM.fsize
        assert blur_fsize % 2 == 1, 'blur conv filter size must be odd'
        self.register_buffer('H_broaden', broaden_filter(PM.ink_a, PM.ink_b))
        self.register_buffer('H_blur', blur_filter(blur_fsize, blur_sigma))
        self.nbroad = PM.ink_ncon
        # Padding that keeps H,W unchanged for an odd-sized blur filter.
        self.blur_pad = blur_fsize // 2
        self.blur_sigma = blur_sigma
        self.blur_fsize = blur_fsize
        self.epsilon = epsilon

    @property
    def device(self):
        """Device holding this module's filter buffers."""
        return self.H_broaden.device

    @property
    def is_cuda(self):
        """True when this module's filter buffers live on a CUDA device."""
        return self.H_broaden.is_cuda

    def forward(self, x, blur_sigma=None, epsilon=None):
        """
        Parameters
        ----------
        x : torch.Tensor
            [H,W] pre-conv image probabilities for a single stroke, or a
            batch of such images with shape [N,H,W]
        blur_sigma : float | None
            Amount of blur. 'None' means use value from __init__ call.
            A value of 0 skips blurring and pixel noise.
        epsilon : float | None
            Pixel-noise level. 'None' means use value from __init__ call

        Returns
        -------
        x : torch.Tensor
            post-conv image probabilities, same shape as the input

        Raises
        ------
        ValueError
            If ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() not in (2, 3):
            raise ValueError(
                f"expected a [H,W] or [N,H,W] image, got shape {tuple(x.shape)}")
        single = x.dim() == 2

        def _unpack(t):
            # Drop the channel dim, and the batch dim for single-image input.
            t = t[:, 0]
            return t[0] if single else t

        # conv2d wants [N, C, H, W]: add the missing dims once, up front.
        x = x[None, None] if single else x[:, None]
        # Follow the filters, wherever the module has been moved.
        x = x.to(self.device)

        if blur_sigma is None:
            H_blur = self.H_blur
            # Use the init-time sigma so the "no blur" shortcut below honours it.
            blur_sigma = self.blur_sigma
        else:
            blur_sigma = check_float_tensor(blur_sigma, self.device)
            H_blur = blur_filter(self.blur_fsize, blur_sigma, device=self.device)

        if epsilon is None:
            epsilon = self.epsilon
        else:
            epsilon = check_float_tensor(epsilon, self.device)

        # Broaden; padding=1 preserves H,W assuming broaden_filter returns a
        # 3x3 kernel (NOTE(review): confirm against broaden_filter).
        for _ in range(self.nbroad):
            x = F.conv2d(x, self.H_broaden, padding=1)
        x = F.hardtanh(x, 0., 1.)

        if blur_sigma == 0:
            return _unpack(x)

        for _ in range(2):
            x = F.conv2d(x, H_blur, padding=self.blur_pad)
        x = F.hardtanh(x, 0., 1.)

        # Pixel noise: mix each probability with its complement.
        if epsilon > 0:
            x = (1 - epsilon) * x + epsilon * (1 - x)
        return _unpack(x)

class Renderer(nn.Module):
    """
    Full rendering pipeline: a Painter that turns a stroke into ink,
    followed by a BroadenAndBlur stage that post-processes the ink image.
    """
    def __init__(self, blur_sigma=0.5, epsilon=0., blur_fsize=None, PM=None):
        """
        blur_sigma, epsilon and blur_fsize are passed to BroadenAndBlur;
        PM (a Parameters instance, built with defaults when None) is shared
        by both stages.
        """
        super().__init__()
        if PM is None:
            PM = Parameters()
        self.painter = Painter(PM)
        self.broaden_and_blur = BroadenAndBlur(blur_sigma, epsilon, blur_fsize, PM)

    def cuda(self, device=None):
        # Painter.forward asserts it is not on CUDA, so only the broaden/blur
        # stage is moved to the GPU; the painter is kept on the CPU.
        self.painter = self.painter.cpu()
        self.broaden_and_blur = self.broaden_and_blur.cuda(device)
        return self

    def forward(self, drawing, blur_sigma=None, epsilon=None):
        """
        Render the drawing by converting the stroke to image ink
        and then applying broaden & blur filters

        Parameters
        ----------
        drawing : torch.Tensor
            Input drawing. The drawing is a single tensor representing the stroke.
        blur_sigma : float | None
            Sigma parameter for blurring. Only used for adaptive blurring.
            Default 'None' means use the blur_sigma from __init__() call
        epsilon : float | None
            Pixel-noise level forwarded to the broaden & blur stage.
            Default 'None' means use the epsilon from __init__() call

        Returns
        -------
        pimg : torch.Tensor
            [H,W] image probabilities after painting and broaden & blur

        Raises
        ------
        NotImplementedError
            Until the student exercise below is completed.
        """

        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        
        if isinstance(drawing, list): # Unwrap a list to its first element (a single stroke)
            drawing = drawing[0]  # Assumes the stroke is the first element if wrapped in a list

        # Process the drawing
        # NOTE(review): Painter.forward hands its argument to add_stroke, which calls
        # torch.flip on it, so it appears to expect a single stroke tensor rather than
        # a list — confirm against drawings_to_cpu before wrapping the stroke in a list.
        pimg = self.painter([...])  # Paint the stroke into an ink image
        pimg = self.broaden_and_blur(..., blur_sigma, epsilon)  # Apply broadening and blurring
        return pimg