In [None]:
using Plots
using LinearAlgebra
using FFTW

include("modulation.jl")
include("bits.jl")

In [None]:
# parameters
Nfft = 16                       # OFDM subcarriers per block (FFT size)
Nblocks = 4                     # number of OFDM blocks to simulate
modOrder = 64                   # QAM constellation size
bps = Int(log2(modOrder))       # bits carried by one QAM symbol
numBits = Nfft * Nblocks * bps  # one symbol per subcarrier per block
modulator = QAM(modOrder)

SNRdB = 30
SNRlin = 10^(SNRdB/10)
# NOTE(review): presumably SNRdB is an Eb/N0 figure since avgEb is used — confirm in modulation.jl
N0 = modulator.avgEb / SNRlin

# generate payload data: bits -> integer symbols -> Gray code -> constellation points
bitstream = generateBitstream(numBits)
symstream = bitstream2symbolstream(bitstream, bps)
symstreamGray = [encodeGray(sym) for sym in symstream]
symbols = modulate(modulator, symstreamGray)

# frequency-domain grid: row = subcarrier, column = OFDM block
X = reshape(symbols, Nfft, Nblocks)
# X = ones(Nfft, Nblocks)   # debug alternative: all-ones grid instead of random data

In [None]:
# generate channel taps

Ntaps = 10                  # channel impulse-response length in samples; tap indices must lie in 1:Ntaps

Ncp = 15                    # cyclic prefix length
Nsps = Nfft + Ncp           # number of samples per block (CP + payload)
Nsamples = Nsps * Nblocks   # total number of time samples

# Sparse power-delay profile: tap i sits at 1-based lag index idxTapDelays[i]
# (i.e. delay idxTapDelays[i]-1 samples) with average power powTapDelaysdB[i] dB.
idxTapDelays = [1, 3, 5, 6, 10]
powTapDelaysdB = [0, -3, -5, -7, -9] # pdp
powTapDelays = [10^(p/10) for p in powTapDelaysdB]
# NOTE(review): the profile is not normalized (sum ≈ 2.14), so the received SNR is
# higher than SNRdB; divide by sum(powTapDelays) if a unit-power channel is intended.

length(idxTapDelays) == length(powTapDelays) ||
    throw(DimensionMismatch("idxTapDelays and powTapDelaysdB must have the same length"))
all(d -> 1 <= d <= Ntaps, idxTapDelays) ||
    throw(ArgumentError("tap indices must lie in 1:Ntaps (Ntaps = $Ntaps)"))
# A CP shorter than the channel memory leaves inter-block interference after CP removal.
Ncp >= Ntaps - 1 ||
    @warn "Ncp = $Ncp is shorter than the channel memory Ntaps - 1 = $(Ntaps - 1)"

# Hcir[t, d] = gain at absolute sample t of the tap with lag d-1.
# Each tap gain is drawn once and held for all samples (time-invariant channel).
Hcir = zeros(ComplexF64, Nsamples, Ntaps)
for (d, p) in zip(idxTapDelays, powTapDelays)
    # randn(ComplexF64) is circularly-symmetric with unit variance (1/2 per component),
    # so this matches sqrt(p/2)*(randn() + im*randn()): a tap of average power p.
    Hcir[:, d] .= sqrt(p) * randn(ComplexF64)
    # time-varying alternative: Hcir[:, d] = sqrt(p) .* randn(ComplexF64, Nsamples)
end

# convolution
# Per-block channel matrices built from the sampled impulse response Hcir:
#   H0[kk, :, :] maps block kk's own transmitted samples onto its received samples;
#   H1[kk, :, :] maps the samples of block kk-1 onto block kk (inter-block tail).
# Entry (nn, mm) is the tap at lag nn - mm (H0) or Nsps + nn - mm (H1), read at the
# absolute time index of output sample nn; lags outside 0:Ntaps-1 stay zero.
H0 = zeros(ComplexF64, Nblocks, Nsps, Nsps)
H1 = zeros(ComplexF64, Nblocks, Nsps, Nsps)

for kk in 1:Nblocks, mm in 1:Nsps, nn in 1:Nsps
    ii = (kk - 1) * Nsps + nn   # absolute time index of output sample nn of block kk
    lag0 = nn - mm              # lag to input sample mm of the same block
    lag1 = Nsps + nn - mm       # lag to input sample mm of the previous block

    if 0 <= lag0 < Ntaps
        H0[kk, nn, mm] = Hcir[ii, lag0 + 1]
    end
    if 0 <= lag1 < Ntaps
        H1[kk, nn, mm] = Hcir[ii, lag1 + 1]
    end
end

# Visualize |Hcir| and the channel matrices of one block (block 2 when it exists,
# so the inter-block matrix H1 is meaningful; falls back to block 1 if Nblocks == 1).
kshow = min(2, Nblocks)
p0 = heatmap(abs.(Hcir), yflip=true, title="|Hcir|")
p1 = heatmap(abs.(H0[kshow, :, :]), yflip=true, title="|H0| block $kshow")
p2 = heatmap(abs.(H1[kshow, :, :]), yflip=true, title="|H1| block $kshow")
p3 = heatmap(abs.(H0[kshow, :, :] + H1[kshow, :, :]), yflip=true, title="|H0+H1| block $kshow")
plot(p0, p1, p2, p3, layout=4)


In [None]:
# ofdm transmit processing

# Unitary DFT matrix (column n is the DFT of the n-th unit vector); Wfft' is the IDFT.
Wfft = inv(sqrt(Nfft)) * fft(I(Nfft), 1);

# CP matrices
Isc = Matrix{Float64}(I, Nfft, Nfft);   # Nfft x Nfft identity
Icp = Isc[end-Ncp+1:end, :];            # last Ncp rows: repeat the block tail up front

Tcp = [Icp; Isc];                       # insert CP: (Ncp+Nfft) x Nfft
Rcp = [zeros(Nfft, Ncp) Isc];           # remove CP: drop the first Ncp samples

# time-domain blocks with CP, one column per OFDM block
Xcp = Tcp * Wfft' * X;

# channel: received block kk = own-block convolution + tail of block kk-1 + noise
Ycp = zeros(ComplexF64, Nsps, Nblocks)
for kk in 1:Nblocks
    # Circularly-symmetric complex Gaussian noise with variance N0 per sample;
    # randn(ComplexF64) has unit variance (1/2 per real/imag component).
    noi = sqrt(N0) .* randn(ComplexF64, Nsps)

    y = H0[kk, :, :] * Xcp[:, kk] + noi
    if kk > 1   # block 1 has no predecessor, hence no inter-block interference term
        y += H1[kk, :, :] * Xcp[:, kk-1]
    end
    Ycp[:, kk] = y
end

# receive ofdm processing: remove the CP and return to the frequency domain
Y = Wfft*Rcp*Ycp;

# equalization: one-tap zero-forcing per subcarrier.
# Hest is block kk's effective frequency-domain channel built from H0 only (the H1 tail
# is assumed to be absorbed by the CP). For taps constant over the block and
# Ncp >= Ntaps-1 it is diagonal; otherwise its off-diagonal (inter-carrier) terms are
# ignored here, since only the diagonal is inverted.
Xest = zeros(ComplexF64, Nfft, Nblocks)
for kk in 1:Nblocks
    Hest = Wfft * Rcp * H0[kk, :, :] * Tcp * Wfft'
    # solve with a Diagonal instead of forming a dense diagm and its explicit inverse
    Xest[:, kk] = Diagonal(diag(Hest)) \ Y[:, kk]
end

# flatten the equalized grid back to a serial stream (column-major: block by block)
symbolsEst = vec(Xest)

# constellation points -> Gray-coded symbols -> natural symbols -> bits
symstreamEst = demodulate(modulator, symbolsEst)
symstreamEstGray = [decodeGray(sym) for sym in symstreamEst]
bitstreamEst = symbolstream2bitstream(symstreamEstGray, bps)

# number of mismatching bits
bitErrors = count(bitstream .!= bitstreamEst)

ber = bitErrors / numBits