In [None]:
@typechecked
def random_ortho_matrix_gen(dim_a: int, dim_b: int) -> Iterator[TensorType["dim_a", "dim_b"]]:
    """
    Endlessly yields random `dim_a x dim_b` matrices whose columns are orthonormal.

    Each round draws one random `dim_a x dim_a` orthogonal matrix with `ortho_group.rvs` and hands out
    `dim_a // dim_b` disjoint groups of `dim_b` columns from it before drawing a new one. The yielded
    matrix can be used directly as a right-hand projection: `X @ W` maps `(..., dim_a)` to `(..., dim_b)`.

    Since this is a generator, the argument checks below run when the first item is requested,
    not when the generator object is created.

    Raises:
        ValueError: if `dim_b < 1` or `dim_a < dim_b`.
    """
    # Without this guard a negative dim_b makes the inner range empty, so the `while True` loop would
    # spin forever without ever yielding; dim_b == 0 would raise ZeroDivisionError.
    if dim_b < 1:
        raise ValueError(f"dim_b must be a positive integer, got {dim_b}")
    if dim_a < dim_b:
        raise ValueError(
            f"Assuming we want projection matrices (dim_a >= dim_b), got dims {dim_a} x {dim_b}"
        )

    blocks_per_draw = dim_a // dim_b
    while True:
        m = ortho_group.rvs(dim=dim_a)
        for i in range(blocks_per_draw):
            # Slice columns (not rows) so the result has shape (dim_a, dim_b) as annotated.
            cols = m[:, i * dim_b : (i + 1) * dim_b]
            # scipy returns float64; convert to torch's default floating dtype.
            yield torch.tensor(cols, dtype=torch.get_default_dtype())

@typechecked
def prior_sampler(name: str, batch_size: int, feature_dim: int) -> Callable[[], 
                                                                            TensorType["batch_size", "feature_dim"]]:
    """
    Looks up a prior distribution by `name` and returns a zero-argument function that draws a fresh
    `(batch_size, feature_dim)` tensor from it on every call.

    Recognised names:
      * "Uniform hypersphere": standard-normal rows, each rescaled to unit L2 norm.
      * "Uniform hypercube": entries from `torch.rand`.
      * "Normal distribution": entries from a standard normal.

    Raises `ValueError` for any other name.
    """
    shape = (batch_size, feature_dim)

    def gaussian_sampler():
        return torch.normal(mean=0, std=1, size=shape)

    def hypercube_sampler():
        return torch.rand(size=shape)

    @typechecked
    def hypersphere_sampler() -> TensorType['batch_size', 'feature_dim']:
        # Draw Gaussian rows, then divide every row by its own L2 norm.
        points = torch.normal(mean=0, std=1, size=shape)
        row_norms = LA.norm(points, dim=1, keepdim=True)
        return points / row_norms

    d : Dict[str,  Callable[[], TensorType["batch_size", "feature_dim"]]] = {
        "Uniform hypersphere": hypersphere_sampler,
        "Uniform hypercube": hypercube_sampler,
        "Normal distribution": gaussian_sampler,
    }

    if name in d:
        return d[name]
    raise ValueError(f"Distr '{name}' not in {d.keys()}")

class SWD_contrastiveloss(nn.Module):
    """
    Loss combining an alignment term between two batches of embeddings with a sliced-distance term
    that compares the embeddings to samples from a named prior along random orthogonal directions.

    The returned value is `loss_align + SWD_lambda * loss_distr`.
    """
    @typechecked
    def __init__(self, batch_size: int, feature_dim: int, prior_name: str,
                 normalize_before_align: bool = True,
                 SWD_dim: int=-1, SWD_lambda: float = 1.):
        """
        Args:
            batch_size: number of rows in each of the two inputs to `forward`.
            feature_dim: length of each embedding vector.
            prior_name: name understood by `prior_sampler`.
            normalize_before_align: if True, rows of both inputs are scaled to unit L2 norm first.
            SWD_dim: number of projection directions per forward pass. The default `-1` is resolved
                to `feature_dim` (project onto a full random orthogonal basis).
            SWD_lambda: weight of the distribution term in the total loss.

        Raises:
            ValueError: if the resolved `SWD_dim` is not in `[1, feature_dim]`.
        """
        super().__init__()
        # -1 is only a placeholder default; handing it to the matrix generator unchanged would make
        # the first `next()` call loop forever without yielding.
        if SWD_dim == -1:
            SWD_dim = feature_dim
        # Validate eagerly: the generator itself would only complain on the first forward pass.
        if not 1 <= SWD_dim <= feature_dim:
            raise ValueError(
                f"SWD_dim must be -1 or within [1, feature_dim={feature_dim}], got {SWD_dim}"
            )

        self.batch_size = batch_size
        self.feature_dim = feature_dim
        self.normalize_before_align = normalize_before_align
        # The prior is sampled with 2*batch_size rows because both inputs are stacked in `forward`.
        self.sample_prior = prior_sampler(prior_name, batch_size=2*batch_size, feature_dim=feature_dim)
        self.ortho_matrix_gen = random_ortho_matrix_gen(feature_dim, SWD_dim)
        self.lmbda = SWD_lambda
        self.SWD_dim = SWD_dim

    @typechecked
    def forward(self, zi: TensorType["batch_size", "feature_dim"],
                zj: TensorType["batch_size", "feature_dim"]):
        """
        Computes the combined loss for two batches of embeddings `zi` and `zj`.

        Returns a scalar tensor: `loss_align + self.lmbda * loss_distr`.
        """
        # Following "Algorithm 1" in the paper
        n : int = self.batch_size
        d = zi.size(dim = -1)

        # Project zi/zj onto hypersphere (i.e. normalize).
        if self.normalize_before_align:
            zi = zi / LA.norm(zi, dim=1, keepdim=True)
            zj = zj / LA.norm(zj, dim=1, keepdim=True)
        loss_align = ((zi - zj)**2).sum() / (n*d)

        Z = torch.cat((zi, zj), dim=0).to(DEVICE)   # (2*batch_size, feature_dim)
        P = self.sample_prior().to(DEVICE)          # (2*batch_size, feature_dim)
        W = next(self.ortho_matrix_gen).to(DEVICE)  # expected (feature_dim, SWD_dim)

        # Sort every projected column independently (dim=0) and compare the sorted embedding values
        # with the sorted prior values position by position; one batched sort replaces a Python
        # loop over the SWD_dim columns.
        H_sorted, _ = torch.sort(Z @ W, dim=0)
        P_sorted, _ = torch.sort(P @ W, dim=0)
        loss_distr = ((H_sorted - P_sorted)**2).sum()
        # NOTE(review): the normaliser uses feature_dim * SWD_dim although the sum runs over
        # 2*batch_size rows; kept unchanged -- confirm against the paper.
        loss_distr = loss_distr / (self.feature_dim * self.SWD_dim)
        return loss_align + self.lmbda * loss_distr