In [1]:
import torch
from models.backbone import Backbone
from models.recognizer import CRNN

# Smoke test: push one random image through the backbone and the recognizer
# and report the tensor shape produced at each stage.
device = "cuda" if torch.cuda.is_available() else "cpu"

backbone = Backbone().to(device)
recognizer = CRNN().to(device)

# Allocate the dummy input directly on the target device instead of building
# it on the CPU and copying it over.
dummy = torch.randn(1, 3, 224, 224, device=device)

# Only shapes are inspected below, so there is no need to record an autograd
# graph for this forward pass.
with torch.no_grad():
    feat = backbone(dummy)
    log_probs = recognizer(feat)

print("Feature map:", feat.shape)
print("Log probs:", log_probs.shape)




Feature map: torch.Size([1, 512, 7, 7])
Log probs: torch.Size([1, 7, 38])


In [2]:
import torch
from utils.reliability import detection_reliability, recognition_reliability, combine_reliability

# Sanity-check the reliability helpers on synthetic inputs.

# Local, seeded generator: makes the fake recognizer output (and therefore the
# printed recognition/final weights) repeatable across runs without touching
# the global RNG state used by other cells.
rng = torch.Generator().manual_seed(0)

# fake detection confidence (one score per sample)
det_conf = torch.tensor([0.2, 0.5, 0.8])

# fake recognition output: random scores turned into log-probabilities over
# the last dimension. The leading size is tied to the number of detection
# scores (3) so the two inputs stay aligned if the list above is edited.
log_probs = torch.randn(len(det_conf), 10, 38, generator=rng).log_softmax(dim=-1)

det_w = detection_reliability(det_conf)
rec_w = recognition_reliability(log_probs)
final_w = combine_reliability(det_w, rec_w)

print("Detection weight:", det_w)
print("Recognition weight:", rec_w)
print("Final weight:", final_w)


Detection weight: tensor([0.0474, 0.5000, 0.9526])
Recognition weight: tensor([0.0419, 0.0421, 0.0441])
Final weight: tensor([0.0020, 0.0210, 0.0420])


In [3]:
import torch
from losses.weighted_loss import WeightedCTCLoss

# Exercise WeightedCTCLoss on a synthetic batch.
criterion = WeightedCTCLoss()

# Local, seeded generator so the printed loss is repeatable across runs
# without altering the global RNG state.
rng = torch.Generator().manual_seed(0)

# fake recognizer output, laid out as (T, B, V) and normalised to
# log-probabilities over the last dimension
T, B, V = 7, 3, 38
log_probs = torch.randn(T, B, V, generator=rng).log_softmax(dim=-1)

# fake targets: a single flat tensor whose length is the sum of the
# per-sample target lengths, so the two can never drift apart.
# Labels are drawn from [1, V); index 0 is never used as a label
# (presumably reserved for the CTC blank - confirm against WeightedCTCLoss).
target_lengths = torch.tensor([4, 4, 4])
targets = torch.randint(1, V, (int(target_lengths.sum()),), generator=rng)
input_lengths = torch.tensor([T] * B)

# fake weights (one per sample)
weights = torch.tensor([0.1, 0.5, 1.0])

loss = criterion(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    weights
)

print("Weighted CTC loss:", loss.item())


Weighted CTC loss: 11.777215003967285


In [4]:
import torch
from models.backbone import Backbone
from models.recognizer import CRNN
from models.teacher_student import TeacherStudentSSL
from losses.weighted_loss import WeightedCTCLoss

# End-to-end check: build the teacher/student wrapper and run one forward
# pass on a synthetic batch.

# Seed the global RNG first: both the model weight initialisation and the
# fake batch below draw from it, so without a seed the printed loss and
# weights change on every run.
torch.manual_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

backbone = Backbone().to(device)
recognizer = CRNN().to(device)
criterion = WeightedCTCLoss()

ssl_model = TeacherStudentSSL(
    backbone,
    recognizer,
    criterion
).to(device)

# fake batch (two samples), allocated directly on the target device
images = torch.randn(2, 3, 224, 224, device=device)
det_conf = torch.tensor([0.3, 0.8], device=device)

T = 7
target_lengths = torch.tensor([5, 5], device=device)
# Flat target tensor sized from target_lengths so the two stay consistent.
# Labels are drawn from [1, 38); 38 matches the last dimension of the
# recognizer output printed by the first cell.
targets = torch.randint(1, 38, (int(target_lengths.sum()),), device=device)
input_lengths = torch.tensor([T, T], device=device)

loss, weights = ssl_model(
    images,
    det_conf,
    targets,
    input_lengths,
    target_lengths
)

print("SSL loss:", loss.item())
print("Reliability weights:", weights)


SSL loss: 0.29958102107048035
Reliability weights: tensor([0.0031, 0.0251], device='cuda:0')


In [5]:
# Take the first parameter tensor of the student and of the teacher
# recognizer, then compare the addresses of their underlying storage:
# equal addresses would mean the two modules share the same weights.
s_param, t_param = (
    next(module.parameters())
    for module in (ssl_model.student_recognizer, ssl_model.teacher_recognizer)
)

shares_storage = s_param.data_ptr() == t_param.data_ptr()
print("Same object:", shares_storage)


Same object: False


In [6]:
# Snapshot the teacher's first parameter, run one teacher update, and report
# how far that parameter moved (L2 norm of the difference).
# NOTE(review): no optimizer step on the student appears in the preceding
# cells, so a near-zero change is plausibly expected here - confirm against
# the update_teacher implementation.
before = next(ssl_model.teacher_recognizer.parameters()).clone()

ssl_model.update_teacher()

after = next(ssl_model.teacher_recognizer.parameters())

delta = after - before
print("Teacher param change magnitude:", torch.norm(delta).item())


Teacher param change magnitude: 2.960731251278048e-07


In [7]:
# Run the same fake batch through the model again, now that update_teacher()
# has been called once, and print the resulting loss and weights.
loss, weights = ssl_model(images, det_conf, targets, input_lengths, target_lengths)

print("Loss:", loss.item())
print("Weights:", weights)


Loss: 0.29958200454711914
Weights: tensor([0.0031, 0.0251], device='cuda:0')


  result = _VF.lstm(


In [8]:
import torch


def perturb_images(images, noise_std=0.02, value_range=(0.0, 1.0)):
    """
    Simple pixel-level perturbation
    (acts as proxy for localization noise)

    Adds zero-mean Gaussian noise to every element of ``images`` and,
    optionally, clamps the result to a value range.

    Args:
        images: floating-point tensor of any shape. It is not modified;
            a new tensor is returned.
        noise_std: standard deviation of the added Gaussian noise.
        value_range: ``(low, high)`` bounds the perturbed tensor is clamped
            to. The default ``(0.0, 1.0)`` keeps the original behaviour.
            Pass ``None`` to skip clamping, e.g. for inputs that are not
            scaled to [0, 1] (mean/std-normalised or otherwise signed
            values), which the default range would truncate.

    Returns:
        Tensor with the same shape, dtype and device as ``images``.
    """
    noise = torch.randn_like(images) * noise_std
    perturbed = images + noise

    # Clamping is opt-out: only inputs known to live in a bounded range
    # should be forced back into it.
    if value_range is None:
        return perturbed

    low, high = value_range
    return perturbed.clamp(low, high)


In [9]:
# Forward pass on the same fake batch as before; prints the returned loss
# and the per-sample weights.
# NOTE(review): perturb_images (defined in the previous cell) is not applied
# to `images` here. If this cell is meant to exercise the perturbation, the
# call would need perturb_images(images) - confirm whether the model applies
# it internally instead.
loss, weights = ssl_model(images, det_conf, targets, input_lengths, target_lengths)

print("Total loss:", loss.item())
print("Weights:", weights)


Total loss: 0.2995823621749878
Weights: tensor([0.0031, 0.0251], device='cuda:0')


In [10]:
from data.ic15_subset import IC15Subset

import os
from pathlib import Path

# Dataset root. Overridable through the IC15_DATA_DIR environment variable so
# the notebook is not tied to a single machine's directory layout; when the
# variable is unset it falls back to the original hard-coded location.
IC15_DATA_DIR = Path(
    os.environ.get("IC15_DATA_DIR", r"D:\semiETS stuffs\semiets_scratch\data\ic15")
)

# Load a small subset (at most 5 samples) with a digits + lowercase vocabulary.
dataset = IC15Subset(
    image_dir=str(IC15_DATA_DIR / "images"),
    annotation_json=str(IC15_DATA_DIR / "ic15_subset.json"),
    vocab="0123456789abcdefghijklmnopqrstuvwxyz",
    max_samples=5
)

# Inspect the first sample: image tensor shape and encoded target sequence.
sample = dataset[0]
print(sample["images"].shape)
print(sample["targets"])


torch.Size([3, 224, 224])
tensor([17, 15, 24, 11, 34, 19, 29, 30, 18, 15, 11, 30, 28, 15])


In [11]:
import importlib

# NOTE(review): the saved output of this cell is an ImportError for
# `ic15_collate_fn`. An earlier cell already imported data.ic15_subset, so if
# the function was added to the file afterwards the kernel is still holding
# the stale, cached module. Reloading picks up the current file contents.
# If the name is genuinely missing from the module, the import below still
# raises ImportError and the module itself needs the function added.
importlib.reload(importlib.import_module("data.ic15_subset"))

from data.ic15_subset import IC15Subset, ic15_collate_fn
print("Import successful")


ImportError: cannot import name 'ic15_collate_fn' from 'data.ic15_subset' (d:\semiETS stuffs\semiets_scratch\data\ic15_subset.py)