In [None]:
"""
Miscellaneous problems
"""

import torch

def kth_finder(x, k, dim=2):
    """Return the k-th largest value of ``x`` along dimension ``dim``.

    ``torch.topk`` returns only the ``k`` largest entries of each slice
    (rather than a fully sorted copy of the dimension). With ``sorted=True``
    they come back in descending order, so the last one is the k-th largest.

    Args:
        x: Input tensor; ``dim`` must be a valid dimension of it.
        k: 1-based rank to look up; must lie between 1 and ``x.size(dim)``.
        dim: Dimension to search. Defaults to 2, which is what this helper
            always used (the last axis of a 3-D tensor). Pass ``-1`` to
            search the last axis of a tensor of any rank.

    Returns:
        A tensor shaped like ``x`` with dimension ``dim`` removed.
    """
    topk = torch.topk(x, k, dim=dim, largest=True, sorted=True).values
    # select() drops `dim` exactly like the old `topk[:, :, -1]` did for
    # dim=2, but without hard-coding how many leading axes precede it.
    return topk.select(dim, -1)

# Demo: draw a random 3x5x5 tensor, then replace it with its 3rd-largest
# entry along the last axis. Each print emits the label, a blank line, the
# tensor, and two trailing blank lines as a section separator.
x = torch.rand(3, 5, 5)
print("Original tensor:\n", x, "\n", sep="\n")

x = kth_finder(x, 3)
print("Kth largest element tensor:\n", x, "\n", sep="\n")

# Experiments with torch.einsum.

# Frobenius inner product of two 3x3 matrices: 'ij, ij ->' pairs every
# entry of a with the same entry of b and, having no output indices,
# sums all the products into a 0-dim tensor.
a = torch.arange(9).view(3, 3)
b = torch.arange(9, 18).view(3, 3)

x = torch.einsum('ij, ij ->', a, b)
print("Dot product of two matrices:\n", a, b, "\nOutput:", x, "\n", sep="\n")

# Transpose-then-multiply inputs: a batch of six 2x3 matrices (a) and a
# batch of six 4x3 matrices (b) that share their last axis.
a = torch.arange(36).view(6, 2, 3)
b = torch.arange(72).view(6, 4, 3)
def transpose_mult(a, b):
    """Multiply ``a`` by the transpose of ``b`` over their last two axes.

    Computes ``out[..., i, j] = sum_k a[..., i, k] * b[..., j, k]``, which is
    ``a @ b.transpose(-1, -2)``. Leading (batch) dimensions are matched via
    einsum's ``...`` ellipsis.
    """
    return torch.einsum('...ik, ...jk -> ...ij', a, b)
# Batched a @ b^T for the (6, 2, 3) and (6, 4, 3) inputs -> shape (6, 2, 4).
x = transpose_mult(a, b)
print("Transpose and then batch multiplication:\n")
print(x)
print("\n")  # section separator, matching the earlier demos


Original tensor:

tensor([[[0.0396, 0.9018, 0.1958, 0.6960, 0.4864],
         [0.0659, 0.6485, 0.2563, 0.0751, 0.3235],
         [0.9859, 0.1446, 0.2228, 0.5228, 0.3609],
         [0.9767, 0.2303, 0.2374, 0.8233, 0.9782],
         [0.9460, 0.2900, 0.2519, 0.9629, 0.1046]],

        [[0.6014, 0.6610, 0.4265, 0.3057, 0.1126],
         [0.3693, 0.2039, 0.5459, 0.9247, 0.5680],
         [0.5119, 0.1395, 0.0062, 0.8578, 0.5855],
         [0.3802, 0.3209, 0.4346, 0.1893, 0.8478],
         [0.6146, 0.4175, 0.9602, 0.8769, 0.2256]],

        [[0.8879, 0.4005, 0.4611, 0.6120, 0.3403],
         [0.2623, 0.6294, 0.5408, 0.4118, 0.9211],
         [0.1276, 0.1643, 0.5981, 0.1355, 0.7518],
         [0.6479, 0.2521, 0.1062, 0.7674, 0.8845],
         [0.9333, 0.1047, 0.0400, 0.4706, 0.8423]]])


Kth largest element tensor:

tensor([[0.4864, 0.2563, 0.3609, 0.8233, 0.2900],
        [0.4265, 0.5459, 0.5119, 0.3802, 0.6146],
        [0.4611, 0.5408, 0.1643, 0.6479, 0.4706]])


Dot product of two matrices