In [1]:
import torch
import numpy as np
from sklearn.manifold import TSNE

from torch.utils.tensorboard import SummaryWriter  
import os
import io
import PIL.Image
from matplotlib import font_manager

def create_tsne_visualization_tensorboard(model, word_to_index, index_to_word, words_to_visualize=None,
                                         perplexity=30, n_iter=1000, random_state=42,
                                         log_dir="runs/sanskrit_tsne", font_path=None):
    embedding_matrix = model.hidden.weight.data.T  
    if words_to_visualize is None:
        words_to_visualize = list(word_to_index.keys())
    else:
        words_to_visualize = [word for word in words_to_visualize if word in word_to_index]
    indices_to_visualize = [word_to_index[word] for word in words_to_visualize]
    embeddings_to_visualize = embedding_matrix[indices_to_visualize]  
    tsne = TSNE(n_components=3, random_state=random_state, perplexity=perplexity, n_iter=n_iter)
    tsne_results = tsne.fit_transform(embeddings_to_visualize.cpu().numpy())  
    writer = SummaryWriter(log_dir)
    writer.add_embedding(embeddings_to_visualize,
                        metadata=words_to_visualize,
                        tag="sanskrit_word_embeddings")
    print(f"Saving TensorBoard data to {log_dir}.  Open TensorBoard to visualize the embeddings.")
    writer.close()

def create_sprite_image(dict_map, image_width=32, image_height=32):
    num_images = len(dict_map)
    n_cols = int(np.ceil(np.sqrt(num_images)))
    n_rows = int(np.ceil(num_images / n_cols))
    sprite_image = np.full((n_rows * image_height, n_cols * image_width, 4), [255, 255, 255, 0], dtype=np.uint8) 
    labels = []
    font_path = "Arial Unicode.ttf"  
    try:
        font_manager.findfont("Arial Unicode")
    except:
        font_path = "FreeSerif.ttf" 
    try:
        font_manager.findfont("FreeSerif")
    except:
        print("No appropriate font found.  Please install a Unicode font (e.g., Arial Unicode MS, FreeSerif) for proper Sanskrit display.")
        return None, None

    try:
        font = ImageFont.truetype(font_path, 20)  
    except OSError as e:
        print(f"Error loading font: {e}")
        print("Please ensure the font file is available at the specified path.")
        return None, None

    for i, (index, text) in enumerate(dict_map.items()):
        row = i // n_cols
        col = i % n_cols
        x1 = col * image_width
        y1 = row * image_height
        x2 = x1 + image_width
        y2 = y1 + image_height
        img = PIL.Image.new('RGBA', (image_width, image_height), color=(255, 255, 255, 0))  
        draw = ImageDraw.Draw(img)
        draw.text((0, 0), text, font=font, fill=(0, 0, 0, 255))  
        sprite_image[y1:y2, x1:x2] = np.array(img)
        labels.append(text)

    return sprite_image, labels

def visualize_embeddings(model, word_to_index, index_to_word, log_dir="runs/embedding_visualization",
                         max_words=1000, image_width=32, image_height=32):
    writer = SummaryWriter(log_dir)

    embedding_matrix = model.hidden.weight.data.T

    if len(word_to_index) > max_words:
        print(f"Visualizing a subset of {max_words} words from the vocabulary.")
        word_to_index_subset = dict(list(word_to_index.items())[:max_words])
        index_to_word_subset = {v: k for k, v in word_to_index_subset.items()}
        indices_to_visualize = list(word_to_index_subset.values())
        embeddings_to_visualize = embedding_matrix[indices_to_visualize]
        metadata = list(word_to_index_subset.keys())

    else:
        embeddings_to_visualize = embedding_matrix
        metadata = list(word_to_index.keys())

    sprite, labels = create_sprite_image(index_to_word, image_width, image_height)

    if sprite is not None:  
        writer.add_image('sanskrit_word_embeddings_sprite', sprite, dataformats='NHWC')
        writer.add_embedding(embeddings_to_visualize,
                            metadata=labels,  
                            label_img=sprite.reshape(-1, image_height, image_width, 4),
                            tag='sanskrit_word_embeddings')
        print(f"Saving TensorBoard data to {log_dir}.  Open TensorBoard to visualize the embeddings.")
    else:
        writer.add_embedding(embeddings_to_visualize,
                            metadata=labels,  
                            tag='sanskrit_word_embeddings')
        print(f"Saving TensorBoard data to {log_dir}.  Open TensorBoard to visualize the embeddings (without sprite).")

    writer.close()

In [2]:
import os
import re
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
text_dir = 'Sans_dataset'

In [3]:
all_text = ""
for file_name in os.listdir(text_dir):
    if file_name.endswith(".txt"):
        with open(os.path.join(text_dir, file_name), 'r', encoding='utf-8') as f:
            all_text += f.read() + "\n"

In [4]:
all_text = re.sub(r'\s+', ' ', all_text).strip()

In [5]:
tokens = all_text.split()
len(tokens)

1028452

In [6]:
word_counts = Counter(tokens)

In [7]:
min_freq = 7  
vocab = [word for word, freq in word_counts.items() if freq >= min_freq]



In [8]:
vocab_size = len(vocab)

In [9]:
word2index = {word: i for i, word in enumerate(vocab)}
index2word = {i: word for word, i in word2index.items()}

In [10]:
filtered_tokens = [w for w in tokens if w in word2index]
data = []
window_size = 2

In [11]:
class Word2VecScratch(nn.Module):
  def __init__(self, vocab_size, embedding_dim):
    super(Word2VecScratch, self).__init__()
    self.hidden = nn.Linear(vocab_size, embedding_dim, bias = False)
    self.output = nn.Linear(embedding_dim, vocab_size, bias = False)
  def forward(self, x):
    x = self.hidden(x)
    x = self.output(x)
    return x

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Word2VecScratch(vocab_size, 500).to(device)

In [13]:
state_dict = torch.load('best_model.pth')
model.load_state_dict(state_dict)


  state_dict = torch.load('best_model.pth')


<All keys matched successfully>

In [14]:
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont

visualize_embeddings(model, word2index, index2word, log_dir="runs/sanskrit_embeddings_demo", max_words=vocab_size)

create_tsne_visualization_tensorboard(model, word2index, index2word, log_dir="runs/sanskrit_tsne_demo", words_to_visualize=list(word2index.keys()))


findfont: Font family ['Arial Unicode'] not found. Falling back to DejaVu Sans.
findfont: Font family ['FreeSerif'] not found. Falling back to DejaVu Sans.


Error loading font: cannot open resource
Please ensure the font file is available at the specified path.
Saving TensorBoard data to runs/sanskrit_embeddings_demo.  Open TensorBoard to visualize the embeddings (without sprite).




Saving TensorBoard data to runs/sanskrit_tsne_demo.  Open TensorBoard to visualize the embeddings.


In [15]:
index2word[7664]

'सर्वात्मने'

In [16]:
index2word[7649]

'गोविन्दाय'