In [2]:
# Train the semantic communication channel and save encoder decoder into /models/
import torch
from torch.utils.data import DataLoader
from codes.channel.proposed_model import SemanticCommunicationChannel
from codes.train_semantic import train_semantic_communication_system
from codes.calculate.utils import load_images, save_model
import math, os

# Pick the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Release cached GPU memory possibly left over from earlier notebook cells.
torch.cuda.empty_cache()

# Prepare hyperparameters
num_epochs = 4000
batch_size = 64
lr = 0.0005
train_rate = 0.8
test_rate = 1 - train_rate
train_snr = 15  # SNR handed to train_semantic_communication_system (units defined there)
split_seed = 0  # fixed seed so the train/test split is reproducible across runs

# Load images to train the semantic communication channel
TRAIN_DIR = "data/coco_1000/train/"
file_count = sum(1 for file in os.listdir(TRAIN_DIR) if file.endswith('.jpg'))
print(file_count)

images = load_images(TRAIN_DIR)
# The rest of the pipeline indexes size(0..3) as (N, C, H, W); fail loudly otherwise
# instead of silently reshaping.
if images.dim() != 4:
    raise ValueError(
        f"expected images of shape (N, C, H, W), got {tuple(images.shape)}"
    )
num_images, image_channels, image_height, image_width = images.shape

# Split sizes are derived from the tensor actually loaded (not from the directory
# listing) so they always sum to len(dataset), as random_split requires.
# round(..., 6) strips float noise such as 0.7 * 10 == 7.000000000000001 before ceil;
# test_size is the exact complement, avoiding floor(N * (1 - train_rate)) dropping
# an image (1 - 0.8 == 0.19999999999999996 -> 199 of 1000).
train_size = min(num_images, math.ceil(round(num_images * train_rate, 6)))
test_size = num_images - train_size

# Shuffle and split the dataset into training and test sets
dataset = torch.utils.data.TensorDataset(images)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training subset is used for fitting; the held-out subset is kept for evaluation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Prepare train channel
channel = SemanticCommunicationChannel()
encoder, decoder = train_semantic_communication_system(
    channel=channel,
    dataloader=train_loader,
    device=device,
    num_epochs=num_epochs,
    train_snr=train_snr,
    lr=lr,
)
save_model(encoder, "encoder_sc5")
save_model(decoder, "decoder_sc5")

1000
